python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch中fuse_modules

pytorch中fuse_modules源码解读

作者:weixin_45919003

这篇文章主要介绍了pytorch中fuse_modules,fuse_known_modules将给定的模块列表mod_list中的一些常见模块进行融合,返回融合后的模块列表,本文通过实例代码详细讲解,需要的朋友可以参考下

1. 官方代码

FUSE_MODULES
TORCH.AO.QUANTIZATION.FUSE_MODULES的源代码

2. fuse_modules源码解读

仅融合以下序列:

网络中所有其他序列保持不变,对于上述序列,用融合的模块替换列表中的第一项,用identity替换其余模块。

fuse_modules

def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):

fuse_known_modules

将给定的模块列表mod_list中的一些常见模块进行融合,返回融合后的模块列表。融合后的模块可以有效地减少模型计算量和内存占用,从而提高模型的计算效率。

参数

def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Returns a list of modules that fuses the operations specified
     in the input module list.
    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity()
    """
    types = tuple(type_before_parametrizations(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
    fused = fuser_method(is_qat, *mod_list)
    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
    # Move pre forward hooks of the base module to resulting fused module
    for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
        fused.register_forward_pre_hook(pre_hook_fn)
        del mod_list[0]._forward_pre_hooks[handle_id]
    # Move post forward hooks of the last module to resulting fused module
    for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
        fused.register_forward_hook(hook_fn)
        del mod_list[-1]._forward_hooks[handle_id]
    new_mod[0] = fused
    for i in range(1, len(mod_list)):
        identity = nn.Identity()
        identity.training = mod_list[0].training
        new_mod[i] = identity
    return new_mod
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
}

fuse_conv_bn

将给定的conv和bn模块融合并返回融合后的模块。

在此函数中构建了一个fused_module_class_map字典,用于指定模块类型与对应的融合模块类型之间的映射关系。

如果其类型在fused_module_class_map字典中有对应的融合模块类型,则将这些模块融合为一个新的模块(ConvBn2d),如果没有对应的融合模块类型,则不对其进行融合处理。

def fuse_conv_bn(is_qat, conv, bn):
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module_class_map = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }
    if is_qat:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
        fused_module_class = fused_module_class_map.get((type(conv)), None)
        if fused_module_class is not None:
            return fused_module_class(conv, bn)
        else:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)

返回调用的 fuse_conv_bn_eval(conv, bn) 函数如下

返回一个新的融合模块,该模块包含了卷积层和BN层的参数,并将其组合成一个新的运算。

def fuse_conv_bn_eval(conv, bn, transpose=False):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)
    return fused_conv

3. fuse_modules实际测试

3.1 modules_to_fuse参数的使用方法

1. 此参数的列表可以包含多个需要融合的组合,子模块列表也可以,使用方法一

方法一:

modules_to_fuse = [ [‘conv1’, ‘bn1’, ‘relu1’], [‘submodule.conv’, ‘submodule.relu’]]

融合ResNet18中layer1的conv和bn层如下:

print('\n Before fusion \n\n', r18_o.layer1)
r18_o.eval()
r18 = torch.quantization.fuse_modules(
    r18_o,
    [['conv1', 'bn1', 'relu'],
     ['layer1.0.conv1', 'layer1.0.bn1'], # , 'layer1.0.relu'],
     ['layer1.0.conv2', 'layer1.0.bn2'],
     ['layer1.1.conv1', 'layer1.1.bn1'], #, 'layer1.1.relu'],
     ['layer1.1.conv2', 'layer1.1.bn2']]
)
print('\n After fusion\n\n', r18.layer1)

结果:

ResNet18融合前:(仅显示ResNet18中layer1的网络结构)

ResNet18融合后

此融合只将Conv2d和BN层进行融合,从上面对比可以看到融合后的 (bn) 变成了 identity(),(conv) 中的Conv2d是原本Conv2d和BN融合的。

2. 如果要融合的module被Sequential封装了,可使用方法二

方法二:

torch.quantization.fuse_modules(m, [‘0’, ‘1’, ‘2’], inplace=True)

1. 使用方法二对ResNet18中模块进行融合操作,融合代码如下:

def fuse_model(self):
    for m in self.modules():
        if type(m) == BasicBlock:
            torch.quantization.fuse_modules(m, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True)

此处代码是仿pytorch官方写MobileNetV2模块融合,这部分代码写在 class ResNet(nn.Module) 中,后面融合直接使用model.fuse_model(),得到的方法二融合ResNet18结果如下:

此处是分别对(conv2d、bn、relu)和(conv2d、bn)进行融合融合

2. 使用方法二对MobileNetv2中模块进行融合操作

def fuse_model(self):
    for m in self.modules():
        if type(m) == ConvBNReLU:
            torch.quantization.fuse_modacules(m, ['0', '1', '2'], inplace=True)
        if type(m) == InvertedResidual:
            for idx in range(len(m.conv)):
                if type(m.conv[idx]) == nn.Conv2d:
                    torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

结果

MobileNetv2融合前(下面结果展示的是第一个残差模块,因此没有第一个1x1的卷积)

MobileNetv2融合后

从此对比可以看到,融合前的conv2d、bn、relu融合成了ConvRelu2d(Conv2d,ReLU),这里面的Conv2d是融合前的Conv2d和BN融合的。

到此这篇关于pytorch中fuse_modules的文章就介绍到这了,更多相关pytorch中fuse_modules内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

阅读全文