windows에서 jax 사용하기

아래 github 저장소를 참고하면 된다.

!https://github.com/cloudhan/jax-windows-builder

2022-10-28 시점으로는 0.3.20이 최신 버전인듯 하다.

image

cpu 버전 설치 명령어

pip install "jax[cpu]===0.3.20" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

cuda 버전 설치

pip install "jax[cuda11_cudnn82]==0.3.20" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

예제 코드

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
print(jnp.max(x))
print(jax.devices())

image

Leave a comment