TensorRT

Contents

  • 1 Problem: TensorRT does not yet implement the einsum operator
  • 2 Examples of replacing einsum with ordinary operators
    • 2.1 How the replacement works
    • 2.2 Conversion examples
      • 2.2.1 Replacing torch.einsum("nctw,cd->ndtw",(a,b))
      • 2.2.2 Replacing torch.einsum('nkctv,kvw->nctw',(a,b))
      • 2.2.3 Replacing torch.einsum("bhnm,bdhm->bdhn",(a,b))
      • 2.2.4 Replacing torch.einsum("nkctv,kcvw->nctw",(a,b))
  • 3 Using the replacement in a pytorch->onnx->TensorRT conversion

1 Problem: TensorRT does not yet implement the einsum operator

ST-GCN uses the Einstein summation operator torch.einsum:

    def forward(self, x, A):
        assert A.size(0) == self.kernel_size
        x = self.conv(x)
        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v)
        x = torch.einsum('nkctv,kvw->nctw', (x, A))
        return x.contiguous(), A

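The Einstein summation convention is convenient, but at the time of writing even the latest release, TensorRT 8, did not support the einsum operator. To deploy an ST-GCN network on TensorRT, einsum support has to be added in one of two ways:

  • The first is to use TensorRT's custom-plugin API and write a plugin that implements einsum yourself.
  • The second is to replace torch.einsum in the PyTorch model with ordinary PyTorch operators, so that converting the ONNX model to a TensorRT model never hits a missing einsum operator.

The first approach is considerably harder than the second: it requires familiarity with TensorRT's custom-plugin workflow and the related APIs. If the second approach works, it largely removes the need for C++ programming against TensorRT, because all the work is done in Python before the ONNX model is exported.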

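2 Examples of replacing einsum with ordinary operators

2.1 How the replacement works

If you are not familiar with einsum, it is worth reading an introduction to it first. The principle is explained here using the following call as the example: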

torch.einsum('nkctv,kvw->nctw', (a, b))

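This operator takes two input tensors:

  • a: a tensor of shape (n,k,c,t,v)
  • b: a tensor of shape (k,v,w)

The two are combined by a batched matrix multiplication over v, summed over k, to produce an output tensor of shape (n,c,t,w).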

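Six dimensions take part in the computation (n, k, c, t, v, w), but the output has only four (n, c, t, w), so k and v are the dimensions being summed over. If we first build an intermediate tensor that keeps every dimension, tensor(n,k,c,t,v,w), and then sum over k and v, we end up with exactly the required tensor(n,c,t,w).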

So how do we expand the result to tensor(n,k,c,t,v,w)?

First reshape a to tensor(n,k,c,t,v,1) and b to tensor(1,k,1,1,v,w), then compute a*b. Broadcasting makes the product a tensor(n,k,c,t,v,w); summing it over k and v gives the final tensor(n,c,t,w).
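Written out with indices, the two steps look like this. The symbol P for the intermediate product is introduced here only for illustration:

```latex
P_{n,k,c,t,v,w} = a_{n,k,c,t,v}\, b_{k,v,w},
\qquad
\mathrm{out}_{n,c,t,w} = \sum_{k}\sum_{v} P_{n,k,c,t,v,w}
                       = \sum_{k}\sum_{v} a_{n,k,c,t,v}\, b_{k,v,w}
```

The right-hand side is exactly what 'nkctv,kvw->nctw' denotes: every index that is missing from the output (k and v) is summed over.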

Is the result still correct after this replacement?

Judging by the outputs of the examples below, which match the einsum results, yes.
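Rather than comparing printed tensors by eye, the equivalence can also be checked numerically. The sketch below is an illustration only: the helper name einsum_free and the small test shapes are made up here, it runs on the CPU, and it uses float64 so that the two summation orders agree to a tight tolerance.

```python
import torch


def einsum_free(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Broadcast-multiply-sum stand-in for torch.einsum('nkctv,kvw->nctw', a, b).

    Hypothetical helper for illustration; not part of PyTorch or ST-GCN.
    """
    n, k, c, t, v = a.shape
    w = b.shape[-1]
    # Broadcasting expands both operands to (n, k, c, t, v, w).
    prod = a.reshape(n, k, c, t, v, 1) * b.reshape(1, k, 1, 1, v, w)
    # Sum over v (dim 4) first, then over k (dim 1) -> (n, c, t, w).
    return prod.sum(dim=4).sum(dim=1)


torch.manual_seed(0)
a = torch.rand(2, 3, 4, 5, 6, dtype=torch.float64)
b = torch.rand(3, 6, 7, dtype=torch.float64)

reference = torch.einsum("nkctv,kvw->nctw", a, b)
replacement = einsum_free(a, b)

print("shapes match:", reference.shape == replacement.shape)
print("allclose:", torch.allclose(reference, replacement))
print("max abs diff:", (reference - replacement).abs().max().item())
```

In float32 the two results can differ in the last few bits because the additions happen in a different order, so a tolerance-based comparison such as torch.allclose is more appropriate than exact equality.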

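In testing, however, this approach is slow in both training and inference, and it consumes a large amount of GPU memory, which can cause out-of-memory errors. A newer replacement method is described in a separate post.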
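The memory cost is easy to estimate, because the intermediate tensor is k*v times larger than the output. The sizes below are hypothetical (a batch of 64, 3 adjacency partitions, 64 channels, 300 frames and 25 joints, stored as float32) and are only meant to show the order of magnitude; they are not measurements from this post.

```python
def tensor_bytes(shape, bytes_per_element=4):
    """Size in bytes of a dense tensor with the given shape (float32 by default)."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element


# Hypothetical ST-GCN-like sizes, chosen only for illustration.
n, k, c, t, v, w = 64, 3, 64, 300, 25, 25

intermediate = tensor_bytes((n, k, c, t, v, w))  # the broadcast product a * b
output = tensor_bytes((n, c, t, w))              # what einsum actually returns

GIB = 1024 ** 3
print(f"intermediate: {intermediate / GIB:.2f} GiB")
print(f"output:       {output / GIB:.2f} GiB")
print(f"ratio:        {intermediate // output}x (= k * v)")
```

With sizes like these the intermediate for a single layer already runs to several GiB, and during training it also has to be kept for the backward pass, which is consistent with the out-of-memory behaviour described above.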

2.2 Conversion examples

2.2.1 Replacing torch.einsum("nctw,cd->ndtw",(a,b))

    import torch

    if __name__ == '__main__':
        n_dim = 2
        c_dim = 3
        t_dim = 4
        w_dim = 5
        d_dim = 6
        a = torch.rand(n_dim, c_dim, t_dim, w_dim).cuda()
        b = torch.rand(c_dim, d_dim).cuda()

        # 1. Using the einsum operator
        a_b_einsum = torch.einsum("nctw,cd->ndtw", (a, b))
        print(a_b_einsum.shape)
        print(a_b_einsum)

        # 2. Replacement
        d = a.reshape(n_dim, c_dim, t_dim, w_dim, 1)        # (n, c, t, w, 1)
        e = b.reshape(1, c_dim, 1, 1, d_dim)                # (1, c, 1, 1, d)
        d_e = d * e                                         # (n, c, t, w, d)
        g = torch.sum(d_e, dim=1)                           # sum over c -> (n, t, w, d)
        g = g.transpose(1, 3).transpose(2, 3).contiguous()  # reorder to (n, d, t, w)
        print(g.shape)
        print(g)

Output:

torch.Size([2, 6, 4, 5])
tensor([[[[0.1536, 0.8770, 1.5611, 1.1315, 0.4630],[1.2332, 0.6786, 0.7968, 0.8069, 1.1939],[0.9592, 1.7275, 1.8042, 0.9783, 1.6064],[0.6743, 0.3755, 1.6635, 1.0484, 1.1097]],[[0.0620, 0.8995, 0.8229, 0.4608, 0.2584],[0.1969, 0.2989, 0.1415, 0.6890, 0.1376],[0.8199, 0.8984, 0.7805, 0.3741, 0.8887],[0.0772, 0.1881, 0.7069, 0.8372, 0.8517]],[[0.1426, 0.9370, 1.5492, 0.9897, 0.4912],[1.2442, 0.6108, 0.7671, 0.8894, 1.1784],[0.9610, 1.7029, 1.9283, 1.0882, 1.7516],[0.8288, 0.3555, 1.7507, 1.1991, 1.1621]],[[0.0770, 0.7392, 0.8970, 0.5596, 0.2783],[0.4787, 0.3499, 0.3109, 0.6130, 0.4321],[0.7142, 0.9846, 0.9704, 0.5076, 0.9691],[0.2692, 0.2081, 0.8840, 0.7735, 0.7843]],[[0.0945, 0.6915, 1.1119, 0.5910, 0.3830],[0.9637, 0.3817, 0.5576, 0.6986, 0.8936],[0.6641, 1.2121, 1.5420, 0.9152, 1.4114],[0.7867, 0.2390, 1.3761, 0.9892, 0.8660]],[[0.1512, 0.5476, 1.4122, 1.1246, 0.4043],[1.3506, 0.6587, 0.8732, 0.5711, 1.3413],[0.6997, 1.5718, 1.6919, 0.9340, 1.3964],[0.7219, 0.3487, 1.5727, 0.7631, 0.8495]]],[[[1.6477, 0.9142, 0.7481, 1.4262, 1.1146],[1.3980, 1.6822, 0.8348, 0.6170, 0.5092],[1.1900, 1.7543, 0.4870, 1.6122, 0.4977],[1.0047, 1.7500, 1.1585, 0.9877, 1.3967]],[[0.5425, 0.1120, 0.2557, 0.9042, 0.6450],[1.0389, 1.0299, 0.6792, 0.1813, 0.0616],[0.6002, 0.9196, 0.1439, 0.8148, 0.3918],[0.4817, 0.7592, 0.2160, 0.5903, 0.3621]],[[1.7228, 0.9820, 0.8371, 1.7044, 1.0725],[1.4199, 1.9031, 0.8216, 0.5613, 0.4442],[1.4446, 1.9075, 0.6006, 1.7294, 0.4762],[1.0802, 1.8771, 1.2108, 0.9614, 1.4954]],[[0.7927, 0.3467, 0.3733, 0.9429, 0.6648],[0.9637, 1.0766, 0.6016, 0.2726, 0.1767],[0.7109, 1.0313, 0.2401, 0.9286, 0.3498],[0.5656, 0.9434, 0.4722, 0.6004, 0.6271]],[[1.3611, 0.8212, 0.7117, 1.4981, 0.7295],[1.0185, 1.5859, 0.5563, 0.3674, 0.2800],[1.3060, 1.5378, 0.5619, 1.3758, 0.3097],[0.8645, 1.5070, 0.9745, 0.6634, 1.2249]],[[1.6435, 1.0060, 0.7292, 1.1459, 0.9853],[1.0979, 1.3947, 0.6303, 0.6454, 0.5835],[1.0327, 1.5530, 0.4744, 1.4485, 0.3857],[0.9158, 1.6384, 1.2431, 0.8607, 
1.4372]]]], device='cuda:0')
torch.Size([2, 6, 4, 5])
tensor([[[[0.1536, 0.8770, 1.5611, 1.1315, 0.4630],[1.2332, 0.6786, 0.7968, 0.8069, 1.1939],[0.9592, 1.7275, 1.8042, 0.9783, 1.6064],[0.6743, 0.3755, 1.6635, 1.0484, 1.1097]],[[0.0620, 0.8995, 0.8229, 0.4608, 0.2584],[0.1969, 0.2989, 0.1415, 0.6890, 0.1376],[0.8199, 0.8984, 0.7805, 0.3741, 0.8887],[0.0772, 0.1881, 0.7069, 0.8372, 0.8517]],[[0.1426, 0.9370, 1.5492, 0.9897, 0.4912],[1.2442, 0.6108, 0.7671, 0.8894, 1.1784],[0.9610, 1.7029, 1.9283, 1.0882, 1.7516],[0.8288, 0.3555, 1.7507, 1.1991, 1.1621]],[[0.0770, 0.7392, 0.8970, 0.5596, 0.2783],[0.4787, 0.3499, 0.3109, 0.6130, 0.4321],[0.7142, 0.9846, 0.9704, 0.5076, 0.9691],[0.2692, 0.2081, 0.8840, 0.7735, 0.7843]],[[0.0945, 0.6915, 1.1119, 0.5910, 0.3830],[0.9637, 0.3817, 0.5576, 0.6986, 0.8936],[0.6641, 1.2121, 1.5420, 0.9152, 1.4114],[0.7867, 0.2390, 1.3761, 0.9892, 0.8660]],[[0.1512, 0.5476, 1.4122, 1.1246, 0.4043],[1.3506, 0.6587, 0.8732, 0.5711, 1.3413],[0.6997, 1.5718, 1.6919, 0.9340, 1.3964],[0.7219, 0.3487, 1.5727, 0.7631, 0.8495]]],[[[1.6477, 0.9142, 0.7481, 1.4262, 1.1146],[1.3980, 1.6822, 0.8348, 0.6170, 0.5092],[1.1900, 1.7543, 0.4870, 1.6122, 0.4977],[1.0047, 1.7500, 1.1585, 0.9877, 1.3967]],[[0.5425, 0.1120, 0.2557, 0.9042, 0.6450],[1.0389, 1.0299, 0.6792, 0.1813, 0.0616],[0.6002, 0.9196, 0.1439, 0.8148, 0.3918],[0.4817, 0.7592, 0.2160, 0.5903, 0.3621]],[[1.7228, 0.9820, 0.8371, 1.7044, 1.0725],[1.4199, 1.9031, 0.8216, 0.5613, 0.4442],[1.4446, 1.9075, 0.6006, 1.7294, 0.4762],[1.0802, 1.8771, 1.2108, 0.9614, 1.4954]],[[0.7927, 0.3467, 0.3733, 0.9429, 0.6648],[0.9637, 1.0766, 0.6016, 0.2726, 0.1767],[0.7109, 1.0313, 0.2401, 0.9286, 0.3498],[0.5656, 0.9434, 0.4722, 0.6004, 0.6271]],[[1.3611, 0.8212, 0.7117, 1.4981, 0.7295],[1.0185, 1.5859, 0.5563, 0.3674, 0.2800],[1.3060, 1.5378, 0.5619, 1.3758, 0.3097],[0.8645, 1.5070, 0.9745, 0.6634, 1.2249]],[[1.6435, 1.0060, 0.7292, 1.1459, 0.9853],[1.0979, 1.3947, 0.6303, 0.6454, 0.5835],[1.0327, 1.5530, 0.4744, 1.4485, 0.3857],[0.9158, 1.6384, 1.2431, 0.8607, 
1.4372]]]], device='cuda:0')

2.2.2 Replacing torch.einsum('nkctv,kvw->nctw',(a,b))

    import torch

    if __name__ == '__main__':
        n_dim = 2
        k_dim = 3
        c_dim = 4
        t_dim = 5
        v_dim = 6
        w_dim = 7
        a = torch.rand(n_dim, k_dim, c_dim, t_dim, v_dim).cuda()
        b = torch.rand(k_dim, v_dim, w_dim).cuda()

        # 1. Using the einsum operator
        a_b_einsum = torch.einsum("nkctv,kvw->nctw", (a, b))
        print(a_b_einsum.shape)
        print(a_b_einsum)

        # 2. Replacement
        d = a.reshape(n_dim, k_dim, c_dim, t_dim, v_dim, 1)  # (n, k, c, t, v, 1)
        e = b.reshape(1, k_dim, 1, 1, v_dim, w_dim)          # (1, k, 1, 1, v, w)
        d_e = d * e                                          # (n, k, c, t, v, w)
        g = d_e.sum(dim=4)                                   # sum over v -> (n, k, c, t, w)
        g = g.sum(dim=1)                                     # sum over k -> (n, c, t, w)
        print(g.shape)
        print(g)

Output:

torch.Size([2, 4, 5, 7])
tensor([[[[5.1243, 6.1886, 6.4867, 5.9503, 4.3884, 5.9156, 6.3786],[3.5094, 4.5025, 4.6040, 4.4659, 3.2761, 4.0113, 4.6133],[2.8960, 4.4600, 4.1601, 4.2410, 3.6444, 3.7244, 4.3930],[3.5727, 5.2637, 5.0044, 4.0297, 3.4084, 3.7042, 5.0993],[4.2162, 4.6827, 4.8894, 4.5648, 4.2200, 4.3652, 4.9702]],[[3.8939, 5.8364, 5.4699, 6.0786, 5.3935, 4.7933, 6.8394],[3.6809, 4.2113, 4.4321, 4.3263, 3.7807, 4.3141, 4.5713],[3.0575, 4.7094, 4.5717, 4.6121, 3.8757, 4.3018, 5.0006],[3.2234, 5.1427, 4.6323, 4.9495, 3.8041, 4.5921, 6.3398],[3.0838, 4.8825, 4.6080, 5.0657, 3.6239, 4.6620, 5.5746]],[[4.7595, 5.1886, 5.4454, 5.1424, 4.6519, 5.8128, 5.7679],[2.0943, 3.4811, 3.4878, 3.4047, 2.7593, 2.7664, 4.2419],[3.1837, 4.4004, 4.3706, 3.9428, 3.6235, 3.9220, 4.5210],[4.3870, 5.1601, 5.4062, 5.1913, 4.6606, 5.0319, 5.4645],[2.8502, 3.9418, 3.8537, 4.1178, 3.3862, 3.5893, 4.7189]],[[4.0604, 4.7589, 5.1743, 5.0615, 3.8071, 4.6072, 5.6293],[2.4517, 4.2132, 3.8871, 3.6280, 3.2739, 3.5921, 4.2781],[3.6410, 4.8560, 4.9386, 4.5525, 3.4472, 4.6471, 5.1113],[3.1657, 4.5200, 4.6471, 4.5107, 3.3468, 4.1749, 5.6870],[3.1199, 4.6795, 4.2116, 4.7081, 3.6071, 4.9994, 4.9361]]],[[[4.7273, 4.9349, 5.3361, 4.8414, 4.2015, 5.2810, 5.3009],[4.6114, 5.0410, 5.8097, 5.3338, 4.8278, 5.4540, 5.9611],[2.7837, 3.5572, 3.1741, 3.8524, 3.2932, 3.7166, 4.0663],[4.3297, 5.0760, 5.7087, 4.9175, 5.1967, 5.3271, 6.2677],[3.7909, 4.4908, 4.7942, 5.0652, 3.3895, 5.2854, 4.9531]],[[3.3919, 4.7877, 4.8535, 4.6241, 4.1152, 4.1662, 5.2712],[4.5097, 4.7338, 5.5651, 5.1715, 4.4806, 4.3849, 5.2941],[3.9918, 5.7766, 5.1338, 6.2193, 4.4749, 5.3787, 6.2975],[3.5283, 3.8201, 4.4238, 4.1005, 3.0723, 4.5533, 4.0787],[3.1878, 3.8659, 4.4109, 4.6049, 3.6049, 4.5219, 4.7915]],[[3.6558, 4.3884, 5.4212, 3.9985, 3.5273, 4.6921, 4.8114],[4.1333, 5.6260, 5.6888, 4.8582, 4.8365, 4.8057, 5.4795],[4.5793, 5.6060, 5.6415, 5.9413, 5.2855, 6.0514, 6.9744],[3.2513, 3.6899, 3.8790, 3.4052, 2.2581, 4.0996, 3.7790],[4.5139, 4.0690, 5.2260, 5.0382, 
3.7201, 4.9167, 5.5515]],[[3.6879, 4.8078, 5.0622, 4.1290, 3.5639, 4.0953, 4.3765],[3.4064, 4.1741, 4.6276, 3.8761, 3.8149, 4.1352, 4.8580],[3.8020, 4.5084, 4.6149, 4.4418, 3.7658, 4.2515, 4.2942],[3.3575, 3.8007, 4.1480, 4.3891, 3.8744, 4.0676, 4.9271],[5.1714, 5.1096, 5.9393, 5.7977, 5.3516, 5.7728, 6.0512]]]],device='cuda:0')
torch.Size([2, 4, 5, 7])
tensor([[[[5.1243, 6.1886, 6.4867, 5.9503, 4.3884, 5.9156, 6.3786],[3.5094, 4.5025, 4.6040, 4.4659, 3.2761, 4.0113, 4.6133],[2.8960, 4.4600, 4.1601, 4.2410, 3.6444, 3.7244, 4.3930],[3.5727, 5.2637, 5.0044, 4.0297, 3.4084, 3.7042, 5.0993],[4.2162, 4.6827, 4.8894, 4.5648, 4.2200, 4.3652, 4.9702]],[[3.8939, 5.8364, 5.4699, 6.0786, 5.3935, 4.7933, 6.8394],[3.6809, 4.2113, 4.4321, 4.3263, 3.7807, 4.3141, 4.5713],[3.0575, 4.7094, 4.5717, 4.6121, 3.8757, 4.3018, 5.0006],[3.2234, 5.1427, 4.6323, 4.9495, 3.8041, 4.5921, 6.3398],[3.0838, 4.8825, 4.6080, 5.0657, 3.6239, 4.6620, 5.5746]],[[4.7595, 5.1886, 5.4454, 5.1424, 4.6519, 5.8128, 5.7679],[2.0943, 3.4811, 3.4878, 3.4047, 2.7593, 2.7664, 4.2419],[3.1837, 4.4004, 4.3706, 3.9428, 3.6235, 3.9220, 4.5210],[4.3870, 5.1601, 5.4062, 5.1913, 4.6606, 5.0319, 5.4645],[2.8502, 3.9418, 3.8537, 4.1178, 3.3862, 3.5893, 4.7189]],[[4.0604, 4.7589, 5.1743, 5.0615, 3.8071, 4.6072, 5.6293],[2.4517, 4.2132, 3.8871, 3.6280, 3.2739, 3.5921, 4.2781],[3.6410, 4.8560, 4.9386, 4.5525, 3.4472, 4.6471, 5.1113],[3.1657, 4.5200, 4.6471, 4.5107, 3.3468, 4.1749, 5.6870],[3.1199, 4.6795, 4.2116, 4.7081, 3.6071, 4.9994, 4.9361]]],[[[4.7273, 4.9349, 5.3361, 4.8414, 4.2015, 5.2810, 5.3009],[4.6114, 5.0410, 5.8097, 5.3338, 4.8278, 5.4540, 5.9611],[2.7837, 3.5572, 3.1741, 3.8524, 3.2932, 3.7166, 4.0663],[4.3297, 5.0760, 5.7087, 4.9175, 5.1967, 5.3271, 6.2677],[3.7909, 4.4908, 4.7942, 5.0652, 3.3895, 5.2854, 4.9531]],[[3.3919, 4.7877, 4.8535, 4.6241, 4.1152, 4.1662, 5.2712],[4.5097, 4.7338, 5.5651, 5.1715, 4.4806, 4.3849, 5.2941],[3.9918, 5.7766, 5.1338, 6.2193, 4.4749, 5.3787, 6.2975],[3.5283, 3.8201, 4.4238, 4.1005, 3.0723, 4.5533, 4.0787],[3.1878, 3.8659, 4.4109, 4.6049, 3.6049, 4.5219, 4.7915]],[[3.6558, 4.3884, 5.4212, 3.9985, 3.5273, 4.6921, 4.8114],[4.1333, 5.6260, 5.6888, 4.8582, 4.8365, 4.8057, 5.4795],[4.5793, 5.6060, 5.6415, 5.9413, 5.2855, 6.0514, 6.9744],[3.2513, 3.6899, 3.8790, 3.4052, 2.2581, 4.0996, 3.7790],[4.5139, 4.0690, 5.2260, 5.0382, 
3.7201, 4.9167, 5.5515]],[[3.6879, 4.8078, 5.0622, 4.1290, 3.5639, 4.0953, 4.3765],[3.4064, 4.1741, 4.6276, 3.8761, 3.8149, 4.1352, 4.8580],[3.8020, 4.5084, 4.6149, 4.4418, 3.7658, 4.2515, 4.2942],[3.3575, 3.8007, 4.1480, 4.3891, 3.8744, 4.0676, 4.9271],[5.1714, 5.1096, 5.9393, 5.7977, 5.3516, 5.7728, 6.0512]]]],device='cuda:0')
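For the 'nkctv,kvw->nctw' pattern, one way to avoid the large six-dimensional intermediate is to fold k and v into a single axis and use an ordinary matrix multiplication. This is a sketch of a standard reformulation, not necessarily the newer method the author refers to; the function name einsum_as_matmul is made up for illustration, and the example runs on the CPU.

```python
import torch


def einsum_as_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute torch.einsum('nkctv,kvw->nctw', a, b) with permute/reshape/matmul.

    Hypothetical helper for illustration only.
    """
    n, k, c, t, v = a.shape
    w = b.shape[-1]
    # Move k next to v, then merge the two summed axes into one of length k*v.
    a2 = a.permute(0, 2, 3, 1, 4).reshape(n, c, t, k * v)  # (n, c, t, k*v)
    b2 = b.reshape(k * v, w)                               # (k*v, w)
    # matmul contracts the merged axis, which sums over k and v at once.
    return torch.matmul(a2, b2)                            # (n, c, t, w)


torch.manual_seed(0)
a = torch.rand(2, 3, 4, 5, 6, dtype=torch.float64)
b = torch.rand(3, 6, 7, dtype=torch.float64)

reference = torch.einsum("nkctv,kvw->nctw", a, b)
result = einsum_as_matmul(a, b)
print(result.shape)
print("allclose:", torch.allclose(reference, result))
```

The largest tensor created here has the size of a itself, not k*v times the size of the output. It also uses only permute, reshape and matmul, which should export to ordinary ONNX operators, although that is worth confirming on your own PyTorch and TensorRT versions.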

2.2.3 Replacing torch.einsum("bhnm,bdhm->bdhn",(a,b))

    import torch

    if __name__ == '__main__':
        b_dim = 2
        h_dim = 3
        n_dim = 4
        m_dim = 5
        d_dim = 6
        a = torch.rand(b_dim, h_dim, n_dim, m_dim).cuda()
        b = torch.rand(b_dim, d_dim, h_dim, m_dim).cuda()

        # 1. Using the einsum operator
        a_b_einsum = torch.einsum("bhnm,bdhm->bdhn", (a, b))
        print(a_b_einsum.shape)
        print(a_b_einsum)

        # 2. Replacement
        d = a.reshape(b_dim, 1, h_dim, n_dim, m_dim)      # (b, 1, h, n, m)
        e = b.reshape(b_dim, d_dim, h_dim, 1, m_dim)      # (b, d, h, 1, m)
        d_e = d * e                                       # (b, d, h, n, m)
        g = torch.sum(d_e, dim=-1)                        # sum over m -> (b, d, h, n)
        print(g.shape)
        print(g)

Output:

torch.Size([2, 6, 3, 4])
tensor([[[[0.7485, 1.1304, 1.5609, 0.7659],[0.4858, 0.6758, 0.7654, 1.5621],[1.0485, 0.9274, 1.7848, 0.9245]],[[0.8395, 1.2818, 1.6657, 0.6278],[0.7742, 0.9645, 1.0017, 2.3575],[1.3290, 1.0029, 2.0633, 1.4623]],[[0.5437, 0.8261, 1.6291, 0.8061],[0.9383, 1.8158, 0.9913, 3.1291],[1.3296, 1.1673, 2.0923, 1.1132]],[[0.6604, 0.9283, 1.3009, 0.7243],[0.6694, 0.8136, 0.6781, 1.4160],[0.5809, 0.3698, 0.8549, 0.7219]],[[0.7077, 1.1118, 1.3823, 0.5207],[0.8915, 1.0434, 1.0218, 2.4020],[1.0928, 0.7864, 1.8233, 1.3614]],[[0.8477, 1.2199, 1.2862, 1.1010],[0.9283, 1.9503, 0.8049, 3.2380],[1.1121, 0.8291, 1.5418, 1.0480]]],[[[1.4059, 1.2436, 0.9244, 1.3208],[0.6139, 0.8953, 1.3918, 0.3312],[0.3111, 0.7085, 0.8762, 1.3002]],[[2.0954, 1.4036, 1.4653, 1.7839],[0.5616, 0.6855, 1.1779, 0.4554],[1.1214, 0.9399, 1.1625, 1.5147]],[[1.5334, 1.1275, 1.0018, 1.3577],[0.7086, 1.0112, 1.6142, 0.5294],[1.0993, 1.0357, 1.3549, 1.8132]],[[1.8053, 1.2803, 1.0696, 1.4491],[0.9126, 1.2431, 1.9852, 0.5952],[1.3642, 1.8556, 1.7529, 2.6744]],[[1.4492, 1.1429, 1.0478, 1.2675],[0.7594, 1.0712, 1.7349, 0.5283],[1.4567, 1.9925, 1.7371, 2.5965]],[[1.8359, 1.0199, 0.9771, 1.4224],[1.5275, 1.6111, 2.1722, 0.8499],[0.5197, 1.2744, 1.1213, 1.8507]]]], device='cuda:0')
torch.Size([2, 6, 3, 4])
tensor([[[[0.7485, 1.1304, 1.5609, 0.7659],[0.4858, 0.6758, 0.7654, 1.5621],[1.0485, 0.9274, 1.7848, 0.9245]],[[0.8395, 1.2818, 1.6657, 0.6278],[0.7742, 0.9645, 1.0017, 2.3575],[1.3290, 1.0029, 2.0633, 1.4623]],[[0.5437, 0.8261, 1.6291, 0.8061],[0.9383, 1.8158, 0.9913, 3.1291],[1.3296, 1.1673, 2.0923, 1.1132]],[[0.6604, 0.9283, 1.3009, 0.7243],[0.6694, 0.8136, 0.6781, 1.4160],[0.5809, 0.3698, 0.8549, 0.7219]],[[0.7077, 1.1118, 1.3823, 0.5207],[0.8915, 1.0434, 1.0218, 2.4020],[1.0928, 0.7864, 1.8233, 1.3614]],[[0.8477, 1.2199, 1.2862, 1.1010],[0.9283, 1.9503, 0.8049, 3.2380],[1.1121, 0.8291, 1.5418, 1.0480]]],[[[1.4059, 1.2436, 0.9244, 1.3208],[0.6139, 0.8953, 1.3918, 0.3312],[0.3111, 0.7085, 0.8762, 1.3002]],[[2.0954, 1.4036, 1.4653, 1.7839],[0.5616, 0.6855, 1.1779, 0.4554],[1.1214, 0.9399, 1.1625, 1.5147]],[[1.5334, 1.1275, 1.0018, 1.3577],[0.7086, 1.0112, 1.6142, 0.5294],[1.0993, 1.0357, 1.3549, 1.8132]],[[1.8053, 1.2803, 1.0696, 1.4491],[0.9126, 1.2431, 1.9852, 0.5952],[1.3642, 1.8556, 1.7529, 2.6744]],[[1.4492, 1.1429, 1.0478, 1.2675],[0.7594, 1.0712, 1.7349, 0.5283],[1.4567, 1.9925, 1.7371, 2.5965]],[[1.8359, 1.0199, 0.9771, 1.4224],[1.5275, 1.6111, 2.1722, 0.8499],[0.5197, 1.2744, 1.1213, 1.8507]]]], device='cuda:0')

2.2.4 Replacing torch.einsum("nkctv,kcvw->nctw",(a,b))

    import torch

    if __name__ == '__main__':
        n_dim = 2
        k_dim = 3
        c_dim = 4
        t_dim = 5
        v_dim = 6
        w_dim = 7
        a = torch.rand(n_dim, k_dim, c_dim, t_dim, v_dim).cuda()
        b = torch.rand(k_dim, c_dim, v_dim, w_dim).cuda()

        # 1. Using the einsum operator
        a_b_einsum = torch.einsum("nkctv,kcvw->nctw", (a, b))
        print(a_b_einsum.shape)
        print(a_b_einsum)

        # 2. Replacement
        d = a.reshape(n_dim, k_dim, c_dim, t_dim, v_dim, 1)  # (n, k, c, t, v, 1)
        e = b.reshape(1, k_dim, c_dim, 1, v_dim, w_dim)      # (1, k, c, 1, v, w)
        d_e = d * e                                          # (n, k, c, t, v, w)
        g = d_e.sum(dim=4)                                   # sum over v -> (n, k, c, t, w)
        g = g.sum(dim=1)                                     # sum over k -> (n, c, t, w)
        print(g.shape)
        print(g)

Output:

torch.Size([2, 4, 5, 7])
tensor([[[[3.6886, 4.6802, 3.4045, 4.3325, 3.7815, 4.4445, 3.6630],[4.3359, 4.6445, 3.5855, 5.3637, 3.4696, 4.9049, 4.8237],[4.1031, 3.5945, 3.3687, 4.5750, 3.4489, 4.6796, 3.9085],[4.4883, 5.3206, 3.9503, 5.5086, 4.2464, 5.0378, 4.7638],[4.7572, 4.4212, 3.6666, 5.2617, 4.8326, 5.6596, 4.1010]],[[4.2742, 5.7762, 5.3966, 6.5586, 5.5349, 6.4641, 6.2513],[5.2969, 6.5553, 5.6853, 6.1812, 6.2185, 5.9048, 6.5468],[4.8239, 6.3793, 6.8062, 6.5025, 6.2583, 5.8821, 6.8121],[4.7322, 4.2670, 4.6996, 4.9050, 4.1259, 5.2866, 5.4602],[3.7583, 4.4740, 5.1556, 4.9040, 5.0501, 4.8276, 4.9932]],[[3.0533, 3.6292, 3.4620, 3.7909, 3.7128, 3.8389, 3.4113],[3.6842, 4.7901, 3.9888, 4.4786, 4.6155, 4.7465, 3.4218],[3.4355, 4.0149, 2.6561, 3.3563, 3.6032, 3.8693, 3.2605],[3.9030, 3.4557, 3.2782, 4.3971, 3.5487, 4.0032, 3.6672],[4.2805, 4.4385, 4.0839, 3.9391, 4.9937, 4.6652, 3.8699]],[[3.5206, 3.4749, 4.0478, 3.8961, 3.4374, 4.8120, 4.3148],[3.5103, 3.8852, 4.5542, 4.8058, 3.5427, 5.1224, 4.1953],[4.8176, 3.5536, 5.1059, 4.4407, 4.0484, 5.4371, 4.3007],[5.0861, 4.3964, 5.3681, 5.0272, 5.2155, 5.6755, 4.9088],[4.1324, 4.3184, 5.0787, 5.3404, 4.9100, 5.7238, 4.6557]]],[[[3.7440, 3.2831, 2.5068, 3.6752, 3.4652, 4.1801, 3.9745],[4.3378, 4.7787, 3.3870, 4.7359, 4.1092, 5.0005, 4.0490],[5.2680, 6.2474, 4.4929, 5.1507, 4.6503, 6.3000, 5.4094],[4.1845, 3.7848, 3.1942, 4.1688, 3.5709, 4.5376, 3.4854],[4.4916, 4.4309, 3.4901, 4.9215, 5.0050, 5.9588, 4.2079]],[[3.8371, 5.5432, 5.2889, 6.0612, 5.0549, 5.5150, 5.5913],[3.2560, 5.1776, 4.4259, 5.4172, 5.1254, 4.6338, 5.5096],[4.8119, 4.6175, 4.5121, 5.3523, 4.7609, 4.8680, 5.1802],[4.6942, 5.6804, 5.6723, 5.9701, 5.1850, 5.4067, 6.0220],[4.1174, 5.3936, 5.5749, 5.5767, 5.0618, 5.0008, 5.8322]],[[3.5461, 4.0710, 3.1561, 4.6295, 4.7151, 4.7079, 3.0401],[3.6591, 3.6458, 3.0113, 4.3259, 3.9257, 4.2912, 3.9183],[2.0226, 2.6473, 2.6139, 4.0020, 2.9864, 3.2500, 2.7126],[3.8039, 3.6157, 3.1439, 3.8964, 3.5679, 4.2046, 3.9065],[3.8460, 4.4008, 3.4570, 4.6452, 
4.0701, 4.9926, 3.4568]],[[4.4745, 3.7711, 5.2040, 5.1922, 5.1194, 5.6031, 4.1346],[3.6972, 3.9783, 5.1616, 4.4995, 3.9579, 5.4845, 4.6744],[4.6724, 3.9152, 4.9706, 5.1842, 3.9180, 5.5437, 3.9079],[3.9847, 3.9112, 4.2006, 4.3341, 3.4161, 4.9568, 4.3607],[3.5865, 3.5239, 4.2117, 4.1858, 3.7528, 4.5605, 3.4211]]]],device='cuda:0')
torch.Size([2, 4, 5, 7])
tensor([[[[3.6886, 4.6802, 3.4045, 4.3325, 3.7815, 4.4445, 3.6630],[4.3359, 4.6445, 3.5855, 5.3637, 3.4696, 4.9049, 4.8237],[4.1031, 3.5945, 3.3687, 4.5750, 3.4489, 4.6796, 3.9085],[4.4883, 5.3206, 3.9503, 5.5086, 4.2464, 5.0378, 4.7638],[4.7572, 4.4212, 3.6666, 5.2617, 4.8326, 5.6596, 4.1010]],[[4.2742, 5.7762, 5.3966, 6.5586, 5.5349, 6.4641, 6.2513],[5.2969, 6.5553, 5.6853, 6.1812, 6.2185, 5.9048, 6.5468],[4.8239, 6.3793, 6.8062, 6.5025, 6.2583, 5.8821, 6.8121],[4.7322, 4.2670, 4.6996, 4.9050, 4.1259, 5.2866, 5.4602],[3.7583, 4.4740, 5.1556, 4.9040, 5.0501, 4.8276, 4.9932]],[[3.0533, 3.6292, 3.4620, 3.7909, 3.7128, 3.8389, 3.4113],[3.6842, 4.7901, 3.9888, 4.4786, 4.6155, 4.7465, 3.4218],[3.4355, 4.0149, 2.6561, 3.3563, 3.6032, 3.8693, 3.2605],[3.9030, 3.4557, 3.2782, 4.3971, 3.5487, 4.0032, 3.6672],[4.2805, 4.4385, 4.0839, 3.9391, 4.9937, 4.6652, 3.8699]],[[3.5206, 3.4749, 4.0478, 3.8961, 3.4374, 4.8120, 4.3148],[3.5103, 3.8852, 4.5542, 4.8058, 3.5427, 5.1224, 4.1953],[4.8176, 3.5536, 5.1059, 4.4407, 4.0484, 5.4371, 4.3007],[5.0861, 4.3964, 5.3681, 5.0272, 5.2155, 5.6755, 4.9088],[4.1324, 4.3184, 5.0787, 5.3404, 4.9100, 5.7238, 4.6557]]],[[[3.7440, 3.2831, 2.5068, 3.6752, 3.4652, 4.1801, 3.9745],[4.3378, 4.7787, 3.3870, 4.7359, 4.1092, 5.0005, 4.0490],[5.2680, 6.2474, 4.4929, 5.1507, 4.6503, 6.3000, 5.4094],[4.1845, 3.7848, 3.1942, 4.1688, 3.5709, 4.5376, 3.4854],[4.4916, 4.4309, 3.4901, 4.9215, 5.0050, 5.9588, 4.2079]],[[3.8371, 5.5432, 5.2889, 6.0612, 5.0549, 5.5150, 5.5913],[3.2560, 5.1776, 4.4259, 5.4172, 5.1254, 4.6338, 5.5096],[4.8119, 4.6175, 4.5121, 5.3523, 4.7609, 4.8680, 5.1802],[4.6942, 5.6804, 5.6723, 5.9701, 5.1850, 5.4067, 6.0220],[4.1174, 5.3936, 5.5749, 5.5767, 5.0618, 5.0008, 5.8322]],[[3.5461, 4.0710, 3.1561, 4.6295, 4.7151, 4.7079, 3.0401],[3.6591, 3.6458, 3.0113, 4.3259, 3.9257, 4.2912, 3.9183],[2.0226, 2.6473, 2.6139, 4.0020, 2.9864, 3.2500, 2.7126],[3.8039, 3.6157, 3.1439, 3.8964, 3.5679, 4.2046, 3.9065],[3.8460, 4.4008, 3.4570, 4.6452, 
4.0701, 4.9926, 3.4568]],[[4.4745, 3.7711, 5.2040, 5.1922, 5.1194, 5.6031, 4.1346],[3.6972, 3.9783, 5.1616, 4.4995, 3.9579, 5.4845, 4.6744],[4.6724, 3.9152, 4.9706, 5.1842, 3.9180, 5.5437, 3.9079],[3.9847, 3.9112, 4.2006, 4.3341, 3.4161, 4.9568, 4.3607],[3.5865, 3.5239, 4.2117, 4.1858, 3.7528, 4.5605, 3.4211]]]],device='cuda:0')

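3 Using the replacement in a pytorch->onnx->TensorRT conversion

If the model is converted along the pytorch->onnx->TensorRT route, the operator replacement above has to be made in the PyTorch model before it is exported to ONNX. The ONNX model then contains no einsum operator, so the onnx->TensorRT step no longer fails with an einsum-not-found error.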
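As a rough sketch of what that looks like in practice, the module below mirrors the ST-GCN forward pass from section 1 with the einsum call replaced, and a helper exports it and lists the operator types in the resulting ONNX graph. The class name GraphConvNoEinsum, the helper export_and_list_ops, the file name and the tensor sizes are all made up for illustration; the helper needs the optional onnx package, and the exact torch.onnx.export arguments may need adjusting for your PyTorch version.

```python
import torch
import torch.nn as nn


class GraphConvNoEinsum(nn.Module):
    """Graph convolution forward pass with torch.einsum replaced (illustrative)."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        # 1x1 convolution producing kernel_size * out_channels feature maps.
        self.conv = nn.Conv2d(in_channels, out_channels * kernel_size, kernel_size=1)

    def forward(self, x, A):
        x = self.conv(x)
        n, kc, t, v = x.size()
        k = self.kernel_size
        x = x.view(n, k, kc // k, t, v)
        # Replacement for: torch.einsum('nkctv,kvw->nctw', (x, A))
        prod = x.unsqueeze(-1) * A.view(1, k, 1, 1, v, A.size(-1))  # (n, k, c, t, v, w)
        return prod.sum(dim=4).sum(dim=1)                            # (n, c, t, w)


def export_and_list_ops(model, x, A, path="gcn_no_einsum.onnx"):
    """Export to ONNX and return the set of operator types in the graph.

    Requires the optional `onnx` package; the file name is a placeholder.
    """
    import onnx  # imported lazily so the module above works without it

    torch.onnx.export(model, (x, A), path, opset_version=12)
    graph = onnx.load(path).graph
    return {node.op_type for node in graph.node}


model = GraphConvNoEinsum(in_channels=3, out_channels=8, kernel_size=3).eval()
x = torch.rand(2, 3, 10, 5)   # (n, c_in, t, v)
A = torch.rand(3, 5, 5)       # (k, v, w)
with torch.no_grad():
    out = model(x, A)
print(out.shape)
```

Calling export_and_list_ops(model, x, A) and checking that "Einsum" is not in the returned set is one way to confirm that the exported graph is free of the operator before handing it to TensorRT.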
