DistributedDataParallel config
Import
from torch.utils.data.distributed import DistributedSampler
# import torch.distributed as dist
main
if __name__ == '__main__':
...
parser.add_argument("--local_rank", type=int, default=0)
...
train
def train(args):
# 初始化,设置通信方式
torch.distributed.init_process_group(backend="nccl")
# 设置当前进程的GPU. local_rank为设备编号
local_rank = args.local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
...
dataset = ...
# 由于每个进程独立进行,需要额外设置sampler
sampler = DistributedSampler(dataset)
# batch_siz设置. 总的batch_size为batch_size*num_gpu
# 使用sampler之后,不能设置shuffle为true
train_loader = DataLoader(train_data, batch_size=6,
num_workers=8,sampler = sampler)
...
# model在各个进程的初始化参数需要相同,可以设置相同的种子
model = ...
# model移入当前进程的GPU
model.to(device)
model = torch.nn.parallel.DistributedDataParallel(\
model,device_ids=[local_rank],\
output_device=local_rank,find_unused_parameters=True)
...
for epoch in range(args.start_ep,20):
# 设置这个地方可以使每个epoch的batch随机
sampler.set_epoch(epoch)
...
for step, batch_x in enumerate(train_loader):
batch_x.to(device)
...
同步输出
使用DDP时每个进程都会输出信息,eg: print, log,并且输出可能不一致,要同步各个进程之间的信息,统一输出,可以用以下代码
if step % 100 == 0:
# 同步数据。只能同步tensor类型的数据
# 同步数据时默认对各个进程的数据求和
torch.distributed.all_reduce(loss)# 同步loss
...
# 现在各个进程之间同步数据之后,再选择在一个进程里进行输出、保存
if torch.distributed.get_rank()==0:
...
else:
pass
Warning
如果出现这个错误,是因为GPU上有不参与计算loss的Variable。如果找不到这些参数,可以在计算loss之后,反传梯度,打印每个parameters()的.grad,如果为None,就是没有参与计算的参数。
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can ena
ble unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participat
e in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` fun
ction. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).