Windows에서 jax 사용하는 방법

다음 github 저장소를 참고하면 된다.

2022-10-28 시점으로는 0.3.20이 최신 버전인듯 하다.

지원 버전 일람

cpu 버전 설치 명령어

pip install "jax[cpu]===0.3.20" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

cuda 버전 설치 명령어

pip install "jax[cuda11_cudnn82]==0.3.20" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

예제 코드

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
print(jnp.max(x))
print(jax.devices())

아래는 결과. 이제 jax를 활용해봐야겠다.

예제 결과



Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • 2025년 회고
  • 2024년 회고
  • Deep Neural Crossover 리뷰