Reference: Bilibili video "5-剪枝后模型参数赋值" (assigning parameters to the pruned model)
https://github.com/foolwood/pytorch-slimming
Paper: Learning Efficient Convolutional Networks through Network Slimming
(1) A convolution produces many feature maps (channel = 64, 128, 256, ...). Not all of them matter equally, so the importance of each feature map has to be quantified.
(2) During training, a strategy is needed to make the corresponding weights differ clearly in magnitude, so that the important feature maps can be selected.
The values of the channel scaling factors are the scores of the feature maps: intuitively, channels with large scores should be kept and channels with small scores can be removed.
Network slimming uses the scaling factor γ inside the BN layers.
Theoretical basis of BN:
$\hat{x} = \dfrac{x - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^{2} + \epsilon}}$, which normalizes the data to a (0, 1) distribution (zero mean, unit variance).
Overall this looks like a plain normalization operation, but BN additionally introduces two learnable parameters: γ and β.
The essence of BatchNorm:
(1) BN pulls the gradually drifting activation distribution back
(2) and re-normalizes it to a standard normal distribution with mean 0 and variance 1.
(3) This keeps the inputs of the activation function in its numerically sensitive range, so training is faster.
(4) The problem this creates: after BN, the values are forced into the (near-)linear region of the nonlinear activation function.
Explanation of point (3):
For a saturating activation function, the two ends are saturated and insensitive (gradients are tiny), while the region near 0 is non-saturated and sensitive.
Explanation of point (4):
BN squeezes the data into the near-linear part of the activation around 0. There F(x) acts only as an affine transformation (take F = sigmoid restricted to its linear region); a composition of affine transformations is still affine, so adding any number of hidden layers would be equivalent to a single-layer network.
So BN has to preserve some nonlinearity, and it applies a further transformation to the normalized result.
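A one-line check of the "stacked affine layers collapse into one" argument (the notation here is added for illustration, with the activation acting affinely in its linear region):

$$W_2 (W_1 x + b_1) + b_2 = (W_2 W_1)\,x + (W_2 b_1 + b_2)$$

The result is again of the form $Wx + b$, so stacking purely affine layers gives no more expressive power than a single layer; this is exactly why γ and β are needed to move the activations back out of the purely linear regime.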
Re-train after adding the two parameters:
$y = \gamma \hat{x} + \beta$, where the two parameters γ and β are learned during training rather than given as hyperparameters.
This formula acts as a partial inverse of the BN normalization:
it reshapes the normalized distribution, scaling and shifting it to restore an appropriate amount of the original distribution.
The larger γ is, the more that channel's output is scaled up, and therefore the more important the corresponding feature map is considered to be.
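In PyTorch, γ and β of nn.BatchNorm2d are exposed as its `weight` and `bias` parameters, which is exactly what the pruning code later reads. A minimal sketch (the channel count and shapes here are arbitrary):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)        # one (gamma, beta) pair per channel
print(bn.weight.shape)        # gamma: torch.Size([4])
print(bn.bias.shape)          # beta:  torch.Size([4])

x = torch.randn(8, 4, 16, 16)
y = bn(x)                     # y = gamma * x_hat + beta, applied per channel
# channels whose |gamma| stays small after training contribute little -> pruning candidates
print(bn.weight.data.abs())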
L1 regularization is applied to the γ parameters to make them sparse.
The derivative of the L1 term is sign(θ): the update has constant magnitude (always ±1 times the penalty strength), so the parameters move steadily toward 0 and many of them end up exactly at 0.
The derivative of the L2 term is θ: the update shrinks as the parameter shrinks, so many parameters become small but rarely reach exactly 0; the effect is smooth shrinkage.
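A minimal sketch of what these two gradients look like on the BN scaling factors (λ plays the role of `scale_sparse_rate` in the training script below; the values are purely illustrative):

import torch

lam = 1e-4                                  # sparsity strength
gamma = torch.tensor([0.8, -0.3, 0.05])     # some BN scaling factors

# L1 subgradient: lam * sign(gamma) -- constant magnitude, keeps pushing toward exactly 0
grad_l1 = lam * torch.sign(gamma)

# L2 gradient: lam * gamma -- shrinks together with gamma, so values rarely reach 0
grad_l2 = lam * gamma

print(grad_l1)   # every entry has magnitude lam
print(grad_l2)   # magnitudes proportional to |gamma|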
Core idea of the paper: during training, add an L1 penalty on the BN scaling factors γ so that they become sparse; treat each channel's γ as its importance score, prune the channels with small γ, and then fine-tune the slimmed network.
The VGG model architecture used (saved as vgg.py, which the later scripts import):
import torch
import torch.nn as nn
import math
from torch.autograd import Variable


class vgg(nn.Module):
    def __init__(self, dataset='cifar10', init_weights=True, cfg=None):
        super(vgg, self).__init__()
        if cfg is None:
            cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]
        self.feature = self.make_layers(cfg, True)
        if dataset == 'cifar10':
            num_classes = 10
        elif dataset == 'cifar100':
            num_classes = 100
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


if __name__ == '__main__':
    net = vgg()
    x = Variable(torch.FloatTensor(16, 3, 40, 40))
    y = net(x)
    print(y.data.shape)
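One detail worth noting: the `cfg` argument is what later makes it possible to rebuild a slimmer network with the same layer pattern but fewer channels per layer. A hypothetical example (these channel numbers are made up for illustration):

pruned_cfg = [32, 48, 'M', 96, 96, 'M', 128, 128, 128, 128, 'M',
              256, 256, 256, 256, 'M', 256, 256, 256, 200]
small_net = vgg(cfg=pruned_cfg)   # same topology as the default cfg, just narrower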
1. Training the original model:
(1) L1 sparsity regularization on BN: after the normal backward pass, the gradients of the BN scaling factors are additionally adjusted with the L1 subgradient (subgradient descent, see updateBN below).
(2) After training, the parameters of the original model are saved.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from vgg import vgg
import shutil
from tqdm import tqdm

learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
epochs = 3
log_interval = 100
batch_size = 100
sparsity_regularization = True
scale_sparse_rate = 0.0001

checkpoint_model_path = 'checkpoint.pth.tar'
best_model_path = 'model_best.pth.tar'

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.Pad(4),
                         transforms.RandomCrop(32),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
    batch_size=batch_size, shuffle=True)

model = vgg()
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        if sparsity_regularization:
            updateBN()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


def save_checkpoint(state, is_best, filename=checkpoint_model_path):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_model_path)


def updateBN():
    # L1 sparsity on the BN scaling factors: add lambda * sign(gamma) to their gradient (subgradient descent)
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(scale_sparse_rate * torch.sign(m.weight.data))


best_prec = 0
for epoch in range(epochs):
    train(epoch)
    prec = test()
    is_best = prec > best_prec
    best_prec = max(prec, best_prec)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict()
    }, is_best)
2. Model pruning
(1) Pruning is done in two steps: first compute the channel masks, then adjust the shape of each layer according to the masks.
(2) BN channel count: the layer order is Conv -> BN -> ReLU -> MaxPool -> Linear, so each BN layer's input dimension equals the output channel count of the preceding Conv.
(3) Total BN channel count: sum the channel counts over all BN layers.
(4) Pruning percentile: sort all BN scaling factors and take the value at the chosen percentile of the total channel count as the threshold (a float); channels whose factor is greater than the threshold get mask = 1, the others get mask = 0.
(5) Adjusting the weights: each BN layer keeps only the channels whose mask is 1. This changes the BN shape, so the neighboring Conv and Linear layers must be adjusted as well; MaxPool and ReLU do not depend on the channel count and are unaffected.
(6) A Conv weight has shape [out_channels, in_channels, kernel_size1, kernel_size2], so it has to be sliced twice: first along in_channels, then along out_channels. The very first Conv keeps its 3 RGB input channels. (A compact sketch of steps (4)-(6) is given right after this list.)
Suppose the computed numbers of retained channels per layer are:
[48, 60, 115, 118, 175, 163, 141, 130, 259, 267, 258, 249, 225, 212, 234, 97]
Then the Conv input/output channels become:
In shape: 3 Out shape:48
In shape: 48 Out shape:60
In shape: 60 Out shape:115
In shape: 115 Out shape:118
……
In shape: 234 Out shape:97
(7) When saving the pruned model, both the retained parameter values and the new model configuration (cfg) are saved, which makes it easy to rebuild the new model structure for later re-training.
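A self-contained toy sketch of steps (4)-(6); the two-stage network, the 50% ratio, and the random γ values are made up purely for illustration, and the real implementation follows below:

import torch
import torch.nn as nn

# toy stand-in: two conv+BN stages, just to demonstrate the thresholding and slicing
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16), nn.ReLU())
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.data.uniform_(0, 1)       # pretend these gammas were learned with L1 sparsity

# (4) gather all |gamma|, take the value at the pruning percentile as the threshold
gammas = torch.cat([m.weight.data.abs().view(-1)
                    for m in model.modules() if isinstance(m, nn.BatchNorm2d)])
thre = torch.sort(gammas)[0][int(gammas.numel() * 0.5)]   # 50% pruning ratio

# (5) per-BN-layer mask: True = keep the channel, False = prune it
mask1 = model[1].weight.data.abs().gt(thre)               # mask of the first BN
mask2 = model[4].weight.data.abs().gt(thre)               # mask of the second BN

# (6) slice the second conv's weight [out, in, k, k]: in_channels first, then out_channels
idx_in = mask1.nonzero().view(-1)
idx_out = mask2.nonzero().view(-1)
w = model[3].weight.data[:, idx_in, :, :][idx_out, :, :, :].clone()
print(w.shape)   # [kept_out, kept_in, 3, 3]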
import os
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from vgg import vgg
import numpy as np
from tqdm import tqdm

percent = 0.5
batch_size = 100
raw_model_path = 'model_best.pth.tar'
save_model_path = 'prune_model.pth.tar'

model = vgg()
model.cuda()
if os.path.isfile(raw_model_path):
    print("==> loading checkpoint '{}'".format(raw_model_path))
    checkpoint = torch.load(raw_model_path)
    start_epoch = checkpoint['epoch']
    best_prec = checkpoint['best_prec']
    model.load_state_dict(checkpoint['state_dict'])
    print("==> loaded checkpoint '{}' (epoch {}) Prec: {:f}".format(raw_model_path, start_epoch, best_prec))
print(model)

# total number of BN channels in the whole network
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

# gather the absolute values of all BN scaling factors
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index: index + size] = m.weight.data.abs().clone()
        index += size

# threshold at the chosen percentile
y, i = torch.sort(bn)
thre_index = int(total * percent)
thre = y[thre_index]

pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(thre).float().cuda()
        pruned += mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
            k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')
pruned_ratio = pruned / total
print('pruned_ratio: {}, Pre-processing Successful!'.format(pruned_ratio))


# simple test of the model after pre-processing pruning (BN scales simply set to zero)
def test():
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
        batch_size=batch_size, shuffle=True)
    model.eval()
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))


test()

# make real prune
print(cfg)
new_model = vgg(cfg=cfg)
new_model.cuda()

layer_id_in_cfg = 0                      # index into cfg / cfg_mask
start_mask = torch.ones(3)               # input of the first conv: 3 RGB channels
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), new_model.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in the final FC layer
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
        w = m0.weight.data[:, idx0, :, :].clone()   # first slice along in_channels
        w = w[idx1, :, :, :].clone()                # then slice along out_channels
        m1.weight.data = w.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        m1.weight.data = m0.weight.data[:, idx0].clone()

torch.save({'cfg': cfg, 'state_dict': new_model.state_dict()}, save_model_path)
print(new_model)
model = new_model
test()
3. Re-training
The parameters saved after pruning are effectively a mid-training checkpoint: rebuild the model from the new (pruned) structure, load this checkpoint, and keep training until the metrics are satisfactory.
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from vgg import vgg
import shutil
from tqdm import tqdm

learning_rate = 0.1
momentum = 0.9
weight_decay = 1e-4
epochs = 3
log_interval = 100
batch_size = 100
sparsity_regularization = True
scale_sparse_rate = 0.0001

prune_model_path = 'prune_model.pth.tar'
prune_checkpoint_path = 'pruned_checkpoint.pth.tar'
prune_best_model_path = 'pruned_model_best.pth.tar'

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.Pad(4),
                         transforms.RandomCrop(32),
                         transforms.RandomHorizontalFlip(),
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('D:\\ai_data\\cifar10', train=False,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
    batch_size=batch_size, shuffle=True)

checkpoint = torch.load(prune_model_path)
model = vgg(cfg=checkpoint['cfg'])
model.cuda()
model.load_state_dict(checkpoint['state_dict'])

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


def save_checkpoint(state, is_best, filename=prune_checkpoint_path):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, prune_best_model_path)


best_prec = 0
for epoch in range(epochs):
    train(epoch)
    prec = test()
    is_best = prec > best_prec
    best_prec = max(prec, best_prec)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec': best_prec,
        'optimizer': optimizer.state_dict()
    }, is_best)
4. Comparison of the original and pruned models
Both were trained on CIFAR-10 with the VGG model for 3 epochs.
The original model is about 156 MB with roughly 70% accuracy.
The pruned model is about 36 MB with roughly 76% accuracy.
Note: ideally the original model should be trained to its peak accuracy before pruning; only then is the before/after accuracy comparison a fair measure of the impact of pruning.
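To sanity-check the size comparison yourself, a rough sketch (the helper below is my own addition; note that the checkpoint file size also includes whatever else was saved in it, e.g. the optimizer state):

import os

def report(model, checkpoint_path):
    # number of weights in the network vs. size of the saved checkpoint on disk
    n_params = sum(p.numel() for p in model.parameters())
    size_mb = os.path.getsize(checkpoint_path) / 1024 / 1024
    print('{:.2f}M parameters, checkpoint {:.1f} MB'.format(n_params / 1e6, size_mb))

# e.g. report(vgg(), 'model_best.pth.tar')
#      report(vgg(cfg=checkpoint['cfg']), 'prune_model.pth.tar')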