admin 管理员组

文章数量: 887021

`named_modules()` 与 `named_parameters()` 的区别

从名字就可以看出两者的区别：`named_modules()` 返回的是（带名字的）模块信息，`named_parameters()` 返回的是（带名字的）参数。不过还是想更直观地对比一下。
文字上的解释可以看这里（原文链接缺失）。

这里展示的网络模型代码来自 SpikingJelly 的 CIFAR-10 示例（原文链接缺失）。
这里所做的修改只是分别打印出这两种不同的信息：

import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven import functional, layer, surrogate, neuron
from torchvision import transforms


class Cifar10Net(nn.Module):
    """Spiking convolutional network for CIFAR-10 built from SpikingJelly LIF neurons.

    The input is encoded once by ``static_conv`` (conv + BN, no spiking
    neurons). ``conv`` and ``fc`` are then applied ``T`` times to that same
    encoding. ``fc`` ends in 100 spiking neurons, which ``boost`` averages in
    groups of 10, and the resulting 10 scores are summed over the ``T`` steps.

    Args:
        T: number of simulation time steps per forward pass.
        v_threshold, v_reset, tau: forwarded unchanged to every
            ``neuron.LIFNode``.
        surrogate_function: surrogate-gradient function shared by every
            ``neuron.LIFNode`` of this network. ``None`` (the default) creates
            a fresh ``surrogate.ATan()`` for each network instance.
    """

    def __init__(self, T=8, v_threshold=1.0, v_reset=0.0, tau=2.0, surrogate_function=None):
        super().__init__()
        # The former default `surrogate.ATan()` was evaluated once at
        # class-definition time, so every network built with the default shared
        # one module object; build it per instance instead.
        if surrogate_function is None:
            surrogate_function = surrogate.ATan()

        # Bookkeeping attributes; not used inside this class.
        # NOTE: "acccuracy" (three c's) is the original attribute name; kept
        # unchanged because code outside this class may read it.
        self.train_times = 0
        self.epochs = 0
        self.max_test_acccuracy = 0
        self.T = T

        def lif():
            """Return a new LIF neuron layer with this network's shared settings."""
            return neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau,
                                  surrogate_function=surrogate_function, detach_reset=True)

        # The trailing "# N" comments are the Sequential indices, i.e. the
        # names reported by named_modules()/named_parameters() (e.g.
        # 'conv.2.weight'). The BatchNorm layers are static_conv.1 and
        # conv.2/5/9/12/15. Layer creation order matches the original, so
        # weight initialisation under a fixed seed is unchanged.
        self.static_conv = nn.Sequential(
            nn.Conv2d(3, 256, kernel_size=3, padding=1, bias=False),  # 0
            nn.BatchNorm2d(256),  # 1
        )

        self.conv = nn.Sequential(
            lif(),  # 0
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),  # 1
            nn.BatchNorm2d(256),  # 2
            lif(),  # 3
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),  # 4
            nn.BatchNorm2d(256),  # 5
            lif(),  # 6
            nn.MaxPool2d(2, 2),  # 7: 16 * 16
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),  # 8
            nn.BatchNorm2d(256),  # 9
            lif(),  # 10
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),  # 11
            nn.BatchNorm2d(256),  # 12
            lif(),  # 13
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),  # 14
            nn.BatchNorm2d(256),  # 15
            lif(),  # 16
            nn.MaxPool2d(2, 2),  # 17: 8 * 8
        )

        # Linear(256 * 8 * 8, ...) fixes the spatial size reaching fc at 8x8,
        # i.e. a 32x32 input after the two 2x2 max-pools.
        self.fc = nn.Sequential(
            nn.Flatten(),  # 0
            layer.Dropout(0.5),  # 1
            nn.Linear(256 * 8 * 8, 128 * 4 * 4, bias=False),  # 2
            lif(),  # 3
            nn.Linear(128 * 4 * 4, 100, bias=False),  # 4
            lif(),  # 5
        )

        # Averages each consecutive group of 10 output neurons: 100 -> 10 scores.
        self.boost = nn.AvgPool1d(10, 10)

    def forward(self, x):
        """Simulate ``self.T`` time steps on one batch and return summed class scores.

        ``static_conv`` runs once, and its output is re-fed to ``conv``/``fc``
        at every step. The returned tensor is the sum over steps of the pooled
        10-way outputs.

        NOTE(review): neuron membrane state is not reset here. Presumably the
        caller resets it between batches (e.g. ``functional.reset_net``);
        confirm.
        """
        x = self.static_conv(x)
        # (batch, 100) -> unsqueeze -> (batch, 1, 100) for AvgPool1d -> (batch, 1, 10) -> squeeze -> (batch, 10)
        out_spikes_counter = self.boost(self.fc(self.conv(x)).unsqueeze(1)).squeeze(1)
        for _ in range(1, self.T):
            out_spikes_counter += self.boost(self.fc(self.conv(x)).unsqueeze(1)).squeeze(1)
        return out_spikes_counter


if __name__ == "__main__":
    net = Cifar10Net()
    # print(net)
    print('named_modules:')
    # named_modules() yields (qualified_name, module) pairs for the network
    # itself (name '') and, recursively, every submodule.
    for name, module in net.named_modules():
        print('name:{}, module {}'.format(name, module))
    print('#####################################################')
    print('named_parameters:')
    # named_parameters() yields (qualified_name, Parameter) pairs such as
    # 'conv.1.weight'; registered buffers (e.g. BN running stats) are not included.
    for name, param in net.named_parameters():
        print('name:{}, param {}'.format(name, param))

最后输出的结果大致是以下两种不同的形式：

上面是各个模块的详细信息（`named_modules()` 的输出）。
下面则是各层的可学习参数，即权重和偏置（`named_parameters()` 的输出）。


可以根据模块信息进行判断，因为模块上可能还挂着可以存取的中间变量。比如下面就是判断模块是否具有某个属性：

     # Initialise a spike count of 0 for every submodule that defines a
     # `monitor` attribute. Assumes `net` and the dict `spike_times` already
     # exist; existing entries for other names are left untouched.
     for name, module in net.named_modules():
         if hasattr(module, 'monitor'):
             spike_times[name] = 0

参数则可以被单独取出来做不同的处理，比如在优化时对 BN 层参数和其他权重分组设置：

    # Split the parameters into BatchNorm parameters (bn_params) and all other
    # weights (weight_params), presumably so they can get different optimizer
    # settings. Counts: w_cnt = non-BN elements, ttl_cnt = all elements.
    # Assumes bn_params / weight_params (lists) and w_cnt / ttl_cnt (ints)
    # were initialised earlier.
    # Entries are module names as reported by named_parameters(); in
    # Cifar10Net the BN layers are static_conv.1 and conv.2/5/9/12/15.
    BN_list = ['static_conv.1', 'conv.2', 'conv.5', 'conv.9', 'conv.12', 'conv.15']
    # Match on "<module>." prefixes instead of plain substrings. A substring
    # test for 'conv.2' would also hit e.g. 'static_conv.2.weight' or
    # 'conv.21.weight'.
    bn_prefixes = tuple(bn_name + '.' for bn_name in BN_list)
    for name, param in net.named_parameters():
        if name.startswith(bn_prefixes):
            bn_params += [param]
        else:
            weight_params += [param]
            w_cnt += param.numel()
        ttl_cnt += param.numel()  # both groups count toward the total

本文标签: named