Windows에서 jax 사용하는 방법
다음 github 저장소를 참고하면 된다.
2022-10-28 시점으로는 0.3.20이 최신 버전인듯 하다.
지원 버전 일람
cpu 버전 설치 명령어
pip install "jax[cpu]===0.3.20" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
cuda 버전 설치 명령어
pip install "jax[cuda11_cudnn82]==0.3.20" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
예제 코드
import jax
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
print(jnp.max(x))
print(jax.devices())
아래는 결과. 이제 jax를 활용해봐야겠다.
예제 결과
Enjoy Reading This Article?
Here are some more articles you might like to read next: