windows에서 jax 사용하기
아래 github 저장소를 참고하면 된다.
!https://github.com/cloudhan/jax-windows-builder
2022-10-28 시점으로는 0.3.20이 최신 버전인듯 하다.
cpu 버전 설치 명령어
pip install "jax[cpu]===0.3.20" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
cuda 버전 설치
pip install "jax[cuda11_cudnn82]==0.3.20" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
예제 코드
import jax
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
print(jnp.max(x))
print(jax.devices())
Leave a comment