torch显卡操作

import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

device0 = torch.device('cuda:0')
device1 = torch.device('cuda:1')
data = torch.rand(300,300)

print(data)
print(torch.cuda.device_count()) # 显卡数量 2
print(torch.cuda.get_device_name()) # 设备名 NVIDIA GeForce RTX 3090
print(torch.cuda.current_device()) # 当前设备编号 0
data0 = data.to(device0)
data1 = data.to(device1)
print(data0.device) # cuda:0
print(data1.device) # cuda:1
time.sleep(60)
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容