python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > PyTorch Distributed Data Parallel

PyTorch Distributed Data Parallel使用详解

作者:光火

这篇文章主要为大家介绍了PyTorch Distributed Data Parallel使用详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

DDP

Distributed Data Parallel 简称 DDP,是 PyTorch 框架下一种适用于单机多卡、多机多卡任务的数据并行方式。由于其良好的执行效率及广泛的显卡支持,熟练掌握 DDP 已经成为深度学习从业者所必备的技能之一。本文结合具体代码,详细地说明了 DDP 在项目中的使用方式。读者按照本文所给的范例,只需稍经调试,即可实现 DDP 的整套流程。

概念辨析

具体讲解 DDP 之前,我们先了解了解它和 Data Parallel (DP) 之间的区别。DP 同样是 PyTorch 常见的多 GPU 并行方式之一,且它的实现非常简洁:

# 函数定义
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
'''
module : 模型
device_ids : 参与训练的 GPU 列表
output_device : 指定输出的 GPU, 通常省略, 即默认使用索引为 0 的显卡
'''
# 程序模板
device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)

基本原理及固有缺陷:在 Data Parallel 模式下,数据会被自动切分,加载到 GPU。同时,模型也将拷贝至各个 GPU 进行正向传播。在多个进程之间,会有一个进程充当 master 节点,负责收集各张显卡积累的梯度,并据此更新参数,再统一发送至其他显卡。因此整体而言,master 节点承担了更多的计算与通信任务,容易造成网络堵塞,影响训练速度。

常见问题及解决方案:Data Parallel 要求模型必须在 device_ids[0] 拥有参数及缓冲区,因此当卡 0 被占用时,可以在 nn.DataParallel 之前添加如下代码:

# 按照 PIC_BUS_ID 顺序自 0 开始排列 GPU 设备
os.environ['CUDA_DEVICE_ORDER'] = 'PIC_BUS_ID'
# 设置当前使用的 GPU 为 2、3 号设备
os.environ['CUDA_VISIBLE_DEVICES'] = '2, 3'

如此,device_ids[0] 将被默认为 2 号卡,device_ids[1] 则对应 3 号卡

相较于 DP, Distributed Data Parallel 的实现要复杂得多,但是它的优势也非常明显:

常用术语

代码实现

Distributed Data Parallel 可以通过 Python 的 torch.distributed.launch 启动器,在命令行分布式地执行 Python 文件。执行过程中,启动器会将当前进程(其实就是 GPU)的 index 通过参数传递给 Python,而我们可以利用如下方式获取这个 index:

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=-1, type=int,
                    metavar='N', help='Local process rank.')
args = parser.parse_args()
# print(args.local_rank)
# local_rank 表示本地进程序号

随后,初始化进程组。对于在 GPU 执行的任务,建议选择 nccl (由 NVIDIA 推出) 作为通信后端。对于在 CPU 执行的任务,建议选择 gloo (由 Facebook 推出) 作为通信后端。倘若不传入 init_method,则默认为 env://,表示自环境变量读取分布式信息

dist.init_process_group(backend='nccl', init_method='env://')
# 初始化进程组之后, 通常会执行这两行代码
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)
# 后续的 model = model.to(device), tensor.cuda(device)
# 对应的都是这里由 args.local_rank 初始化得到的 device

数据部分,使用 Distributed Sampler 划分数据集,并将 sampler 传入 data loader。需要注意的是,此时在 data loader 中不能指定 shuffle 为 True,否则会报错 (sampler 已具备随机打乱功能)

dev_sampler = data.DistributedSampler(dev_data_set)
train_sampler = data.DistributedSampler(train_data_set)
dev_loader = data.DataLoader(dev_data_set, batch_size=dev_batch_size, 
                             shuffle=False, sampler=dev_sampler)
train_loader = data.DataLoader(train_data_set, batch_size=train_batch_size, 
                               shuffle=False, sampler=train_sampler)

模型部分,首先将将模型送至 device,即对应的 GPU 上,再使用 Distributed Data Parallel 包装模型(顺序颠倒会报错)

model = model.to(device)
model = nn.parallel.DistributedDataParallel(
    model, device_ids=[args.local_rank], output_device=args.local_rank
)

Distributed Data Parallel 模式下,保存模型应使用 net.module.state_dict(),而非 net.state_dict()。且无论是保存模型,还是 LOGGER 打印,只对 local_rank 为 0 的进程操作即可,因此代码中会有很多 args.local_rank == 0 的判断

if args.local_rank == 0:
    LOGGER.info(f'saving latest model: {output_path}')
    torch.save({'model': model.module.state_dict(), 
                'optimizer': None, 'epoch': epoch, 'best-f1': best_f1}, 
               open(os.path.join(output_path, 'latest_model_{}.pth'.format(fold)), 'wb'))

利用 torch.load 加载模型时,设置 map_location=device,否则卡 0 会承担更多的开销

load_model = torch.load(best_path, map_location=device)
model.load_state_dict(load_model['model'])

启动方式

torch.distributed.launch

# 此处 --nproc_per_node 4 的含义是 server 有 4 张显卡
python torch.distributed.launch --nproc_per_node 4 train.py
# 倘若使用 nohup, 则注意输入命令后 exit 当前终端
python torch.distributed.launch --nproc_per_node 4 train.py

代码中,只需将 Argument Parser 相关的部分替换为

local_rank = int(os.environ['LOCAL_RANK'])

然后将 args.local_rank 全部改为 local_rank 即可

启动命令

# 单机多卡训练时, 可以不指定 nnodes
torchrun --nnodes=1 --nproc_per_node=4 train.py
# 倘若使用 nohup, 则注意输入命令后 exit 当前终端
nohup torchrun --nnodes=1 --nproc_per_node=4 train.py > nohup.out &

以上就是PyTorch Distributed Data Parallel使用详解的详细内容,更多关于PyTorch Distributed Data Parallel的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文