if args.local_rank not in [-1, 0]:
    torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

... (loads the model and the vocabulary)

if args.local_rank == 0:
    torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
Understanding:
When four processes (0, 1, 2, 3) train a model on multiple GPUs, the training itself runs synchronously in parallel, but steps such as downloading the model, reading data, and preprocessing do not need to be done by every process. Usually only the main process (local_rank = 0) performs them.
When the other processes (local_rank != 0) reach the first if statement, they call torch.distributed.barrier() and block there, while the main process continues and loads the model and vocabulary. When the main process reaches the second if statement, it also enters the barrier. Once every process has reached a barrier, torch.distributed.barrier() releases all of them, and the non-zero processes can then load the files the main process has already downloaded.
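Below is a minimal, self-contained sketch of the same barrier pattern that can be run on CPU with the "gloo" backend. The function names (worker, load_model_and_vocab), the world size of 4, and the master address/port are illustrative assumptions, not part of the original example.

import os
import torch.distributed as dist
import torch.multiprocessing as mp


def load_model_and_vocab(rank):
    # Placeholder for the expensive, non-parallel work (download / preprocessing).
    print(f"[rank {rank}] loading model & vocab ...")


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank != 0:
        dist.barrier()          # non-zero ranks block here

    load_model_and_vocab(rank)  # rank 0 runs this first; the others run it after release

    if rank == 0:
        dist.barrier()          # once rank 0 arrives, every rank has entered a barrier and all are released

    print(f"[rank {rank}] continues with training")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 4
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)

Running it shows rank 0 printing the loading message first, then ranks 1-3 loading only after rank 0 has reached the second barrier.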
Reference: Distributed communication package - torch.distributed — PyTorch 1.11.0 documentation