1. Environment Preparation
- OS: Ubuntu 20.04.1 LTS
- CPU: 128C
- Memory: 1T
- Disk: 10T
- GPU: NVIDIA A800-SXM4-80GB * 8
- Driver: 530.30.02
- CUDA Version: 12.1
2. Download
- Code download

```bash
git clone https://github.com/xai-org/grok-1.git
cd grok-1
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
mkdir checkpoints/ckpt-0
```
- Model download: huggingface-cli kept getting stuck at 99%, so download with wget directly:

```bash
sh urls.sh    # see urls.sh below
nohup sh download.sh > log 2>&1 &
```
- urls.sh (generates download.sh)

```bash
#!/bin/bash
echo "#!/bin/bash" > download.sh
for i in {0..769}
do
    seq="tensor0000"$i"_000"
    if [ $i -ge 10 ]
    then
        seq="tensor000"$i"_000"
    fi
    if [ $i -ge 100 ]
    then
        seq="tensor00"$i"_000"
    fi
    echo "wget -O checkpoints/ckpt-0/$seq https://hf-mirror.com/xai-org/grok-1/resolve/main/ckpt-0/$seq?download=true" >> download.sh
done
```
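The same download.sh can also be generated with a short Python script, where the zero-padding of the shard names is a bit easier to follow. A sketch equivalent to urls.sh above (not part of the original notes):

```python
# Python equivalent of urls.sh: write one wget line per checkpoint shard
# (tensor00000_000 .. tensor00769_000) into download.sh.
BASE = "https://hf-mirror.com/xai-org/grok-1/resolve/main/ckpt-0"

with open("download.sh", "w") as f:
    f.write("#!/bin/bash\n")
    for i in range(770):
        seq = f"tensor{i:05d}_000"
        f.write(f"wget -O checkpoints/ckpt-0/{seq} {BASE}/{seq}?download=true\n")
```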
3. Running and Usage
- Test whether JAX can see the GPUs

```python
import torch
import jax

print(torch.cuda.is_available())
print(jax.devices())
```
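Since run.py builds a (1, 8) device mesh (see Issue 1 below), it can help to assert up front that all eight cards are actually visible. A small sketch, not in the original notes:

```python
# Fail fast if JAX does not see all 8 CUDA devices; run.py's mesh expects (1, 8).
import jax

devices = jax.devices()
print(devices)
assert len(devices) == 8 and devices[0].platform == "gpu", (
    "expected 8 visible CUDA devices; check the jax[cuda12] install "
    "and CUDA_VISIBLE_DEVICES"
)
```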
- Modify run.py (the tail of main(), around lines 64-73); a variant that takes the prompt as a CLI flag is sketched after the example output below

```python
gen = inference_runner.run()

# inp = "The answer to life the universe and everything is of course"
inp = "write quick sort with golang"  # custom prompt
# print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=300, temperature=0.01))  # increase max_len


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
```
- prompt: write quick sort with golang
- output:
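To try other prompts without editing run.py each time, the hardcoded values could be read from the command line instead. A minimal sketch of the tail of main(); the flag names are made up here and are not part of the original run.py:

```python
# Sketch only: replace the hardcoded prompt/max_len with CLI flags.
# `inference_runner` and `sample_from_model` are the objects already used in run.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--prompt", default="write quick sort with golang")
parser.add_argument("--max-len", type=int, default=300)
parser.add_argument("--temperature", type=float, default=0.01)
args = parser.parse_args()

gen = inference_runner.run()
print(
    f"Output for prompt: {args.prompt}",
    sample_from_model(gen, args.prompt, max_len=args.max_len, temperature=args.temperature),
)
```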
- GPU resource usage
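To watch how the eight cards are used while the checkpoint loads and during generation, a small monitoring sketch (assuming nvidia-smi is on PATH; not part of the original write-up):

```python
# Print per-GPU memory use and utilisation every 5 seconds (Ctrl-C to stop).
import subprocess
import time

while True:
    out = subprocess.run(
        ["nvidia-smi",
         "--query-gpu=index,memory.used,utilization.gpu",
         "--format=csv,noheader"],
        capture_output=True, text=True, check=True,
    ).stdout
    print(out, flush=True)
    time.sleep(5)
```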
4. TODO

Build it into a serving endpoint.
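One possible shape for this, as a rough, untested sketch: expose sample_from_model behind an HTTP endpoint. FastAPI/uvicorn are extra dependencies not in the repo's requirements, and build_inference_runner() below is a hypothetical helper that would have to reproduce the InferenceRunner setup currently living inside run.py's main():

```python
# Hypothetical serving sketch (untested). build_inference_runner() is a stand-in
# for the runner construction done in run.py's main().
from fastapi import FastAPI
from pydantic import BaseModel

from runners import sample_from_model  # assumption: same import that run.py uses


def build_inference_runner():
    # Copy the InferenceRunner(...) construction out of run.py's main() here.
    raise NotImplementedError


app = FastAPI()

inference_runner = build_inference_runner()
inference_runner.initialize()
gen = inference_runner.run()


class GenerateRequest(BaseModel):
    prompt: str
    max_len: int = 300
    temperature: float = 0.01


@app.post("/generate")
def generate(req: GenerateRequest):
    output = sample_from_model(
        gen, req.prompt, max_len=req.max_len, temperature=req.temperature
    )
    return {"output": output}
```

If saved as serve.py, it could be started with something like `uvicorn serve:app --host 0.0.0.0 --port 8000`, using a single worker since one model instance already occupies all eight GPUs.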
5. Common Issues

Issue 1
- Symptom

```text
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
INFO:rank:Initializing mesh for self.local_mesh_config=(1, 8) self.between_hosts_config=(1, 1)...
INFO:rank:Detected 1 devices in mesh
Traceback (most recent call last):
  File "/data/grok-1/run.py", line 72, in <module>
    main()
  File "/data/grok-1/run.py", line 63, in main
    inference_runner.initialize()
  File "/data/grok-1/runners.py", line 282, in initialize
    runner.initialize(
  File "/data/grok-1/runners.py", line 181, in initialize
    self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/grok-1/runners.py", line 586, in make_mesh
    device_mesh = mesh_utils.create_hybrid_device_mesh(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/miniconda3/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 373, in create_hybrid_device_mesh
    per_granule_meshes = [create_device_mesh(mesh_shape, granule)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/miniconda3/lib/python3.12/site-packages/jax/experimental/mesh_utils.py", line 302, in create_device_mesh
    raise ValueError(f'Number of devices {len(devices)} must equal the product '
ValueError: Number of devices 1 must equal the product of mesh_shape (1, 8)
```
- Solution: modify requirements.txt so that a CUDA-enabled jaxlib is installed, then rerun pip install -r requirements.txt

```text
dm_haiku==0.0.12
jax[cuda12]==0.4.25
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
numpy==1.26.4
sentencepiece==0.2.0
```
Issue 2

- Symptom

```text
2024-03-21 15:27:29.467666: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:280] failed call to cuInit: CUDA_ERROR_NOT_INITIALIZED: initialization error
Traceback (most recent call last):
  File "/data/miniconda3/envs/grok/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 694, in backends
    backend = _init_backend(platform)
  File "/data/miniconda3/envs/grok/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 776, in _init_backend
    backend = registration.factory()
  File "/data/miniconda3/envs/grok/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 514, in factory
    return xla_client.make_c_api_client(plugin_name, options, None)
  File "/data/miniconda3/envs/grok/lib/python3.10/site-packages/jaxlib/xla_client.py", line 197, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/grok-1/testgpu.py", line 3, in <module>
    print(jax.devices())
  File "/data/miniconda3/envs/grok/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 887, in devices
    return get_backend(backend).devices()
  File "/data/miniconda3/envs/grok/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 821, in get_backend
    return _get_backend_uncached(platform)
  File "/data/miniconda3/envs/grok/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 801, in _get_backend_uncached
    bs = backends()
  File "/data/miniconda3/envs/grok/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
```
- Solution: install cuDNN.
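A quick way to confirm whether cuDNN is actually visible to the process before relaunching run.py (a sketch; the soname below assumes cuDNN 8.x, which is what jax 0.4.25 with CUDA 12 expects, so adjust it if your installation differs):

```python
# Check that the cuDNN shared library can be loaded.
# "libcudnn.so.8" is an assumption (cuDNN 8.x); change it for other major versions.
import ctypes

try:
    ctypes.CDLL("libcudnn.so.8")
    print("cuDNN found")
except OSError as err:
    print("cuDNN not found:", err)
```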