MobileNet(PyTorch)


深度学习神经网络特征提取(六)

本文给出 MobileNet 的 PyTorch 版本代码。关于网络结构的讲解,请参考前期的文章。

MobileNetv1

class DepthwiseSeparabel(nn.Module):
    """MobileNetV1 depthwise-separable convolution unit.

    3x3 depthwise conv (one filter per input channel, padding 1) -> BN -> ReLU6,
    followed by a 1x1 pointwise conv -> BN -> ReLU6.

    Args:
        input_channel: channels of the input (and of the depthwise stage).
        output: channels produced by the 1x1 pointwise conv.
        stride: stride of the depthwise conv (1 keeps the spatial size).

    NOTE(review): the MobileNetv2 section redefines a class with this same
    name (without the trailing ReLU6); if both snippets are placed in one
    module, the later definition is the one callers resolve.
    """

    def __init__(self, input_channel, output, stride=1):
        super().__init__()
        # Depthwise stage: groups == channels, so each channel is filtered on its own.
        self.depth_wise = nn.Conv2d(
            input_channel,
            input_channel,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=input_channel,
        )
        self.batch1 = nn.BatchNorm2d(input_channel)
        self.relu1 = nn.ReLU6(inplace=True)
        # Pointwise stage: mixes channels and sets the output width.
        self.separable = nn.Conv2d(input_channel, output, kernel_size=1, stride=1)
        self.batch2 = nn.BatchNorm2d(output)
        self.relu2 = nn.ReLU6(inplace=True)

    def forward(self, x):
        pipeline = (
            self.depth_wise,
            self.batch1,
            self.relu1,
            self.separable,
            self.batch2,
            self.relu2,
        )
        for stage in pipeline:
            x = stage(x)
        return x
class MobileNetv1(nn.Module):
    """MobileNet v1 classifier built from depthwise-separable units.

    The stem conv plus four stride-2 units downsample the input by 32x
    (e.g. 224x224 -> 7x7); the adaptive pooling makes the head independent
    of the exact input resolution.

    Args:
        num_classes: number of logits produced by the final linear layer.

    NOTE(review): ``DepthwiseSeparabel`` is looked up by name when this
    constructor runs. The MobileNetv2 section redefines that name without the
    trailing ReLU6; if both snippets share one module, that later definition
    is the one used here. Keep them in separate modules (or rename one).
    """

    def __init__(self, num_classes):
        super(MobileNetv1, self).__init__()
        self.model = nn.Sequential(
            # Stem: standard 3x3 conv, stride 2.
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
            DepthwiseSeparabel(32, 64),
            DepthwiseSeparabel(64, 128, 2),
            DepthwiseSeparabel(128, 128),
            DepthwiseSeparabel(128, 256, 2),
            DepthwiseSeparabel(256, 256),
            DepthwiseSeparabel(256, 512, 2),
            # Five repeated 512 -> 512 units at constant resolution.
            *[DepthwiseSeparabel(512, 512) for _ in range(5)],
            DepthwiseSeparabel(512, 1024, 2),
            DepthwiseSeparabel(1024, 1024),
        )
        # Global *average* pooling, as specified by the MobileNet architecture
        # (max pooling here would change the model's behavior and features).
        self.avg = nn.AdaptiveAvgPool2d(1)

        self.drop1 = nn.Dropout(0.5)
        self.linear1 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.model(x)
        x = self.avg(x)
        # (N, 1024, 1, 1) -> (N, 1024)
        x = x.view(x.size(0), x.size(1))
        x = self.drop1(x)
        x = self.linear1(x)
        return x

MobileNetv2

class DepthwiseSeparabel(nn.Module):
    """Depthwise-separable unit with a *linear* projection (MobileNetV2 style).

    3x3 depthwise conv (padding 1) -> BN -> ReLU6, then 1x1 pointwise conv -> BN,
    with no activation after the projection.

    Args:
        input_channel: channels of the input (and of the depthwise stage).
        output: channels produced by the 1x1 projection.
        stride: stride of the depthwise conv (1 keeps the spatial size).
    """

    def __init__(self, input_channel, output, stride=1):
        super().__init__()
        # groups == channels: one 3x3 filter per channel.
        self.depth_wise = nn.Conv2d(
            input_channel,
            input_channel,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=input_channel,
        )
        self.batch1 = nn.BatchNorm2d(input_channel)
        self.relu1 = nn.ReLU6(inplace=True)
        self.separable = nn.Conv2d(input_channel, output, kernel_size=1, stride=1)
        self.batch2 = nn.BatchNorm2d(output)

    def forward(self, x):
        hidden = self.relu1(self.batch1(self.depth_wise(x)))
        # Linear bottleneck: the projection output is left un-activated.
        return self.batch2(self.separable(hidden))
class inverted_res_block(nn.Module):
    """MobileNetV2 inverted residual block.

    Optional 1x1 expansion (conv -> BN -> ReLU6) to ``expansion * input_channel``
    channels, then the ``DepthwiseSeparabel`` unit defined alongside this class
    (3x3 depthwise + BN + ReLU6, then 1x1 linear projection + BN) down to
    ``output`` channels.

    Args:
        input_channel: channels of the input.
        output: channels of the block output.
        stride: stride of the depthwise conv.
        expansion: expansion factor t.
        first_inverted_res_block: if True, skip the 1x1 expansion. No conv then
            changes the channel count before the depthwise stage, so this flag
            expects ``expansion == 1``.
    """

    def __init__(self, input_channel, output, stride, expansion, first_inverted_res_block=False):
        super(inverted_res_block, self).__init__()
        hidden = expansion * input_channel
        if not first_inverted_res_block:
            self.model = nn.Sequential(
                nn.Conv2d(input_channel, hidden, kernel_size=1, stride=1),
                nn.BatchNorm2d(hidden),
                nn.ReLU6(),
            )
        else:
            # Empty Sequential acts as the identity.
            self.model = nn.Sequential()
        self.depth_wise_separable = DepthwiseSeparabel(hidden, output, stride=stride)
        # Shortcut only when the block is shape-preserving by construction
        # (stride 1 and same width). A runtime shape comparison would also add
        # a shortcut to stride-2 blocks whenever the feature map is 1x1.
        self.use_residual = stride == 1 and input_channel == output

    def forward(self, x):
        out = self.depth_wise_separable(self.model(x))
        if self.use_residual:
            out = out + x
        return out
class MobileNetv2(nn.Module):
    """MobileNetV2 classifier.

    The stem conv and four stride-2 stages downsample by 32x, so a
    3x224x224 input reaches the head as 1280x7x7 before global average pooling.

    Args:
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes):
        super(MobileNetv2, self).__init__()
        # One row per stage of the MobileNetV2 architecture table:
        # (expansion t, output channels c, repeats n, stride s of the first repeat).
        # Later repeats in a stage always use stride 1.
        settings = [
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 4, 2),
            # Stride 1 for the 96-channel stage; stride 2 here would add a sixth
            # downsampling and shrink a 224x224 input to 4x4 instead of 7x7.
            (6, 96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        ]
        layers = [
            nn.Conv2d(3, 32, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU6(),
        ]
        in_ch = 32
        for t, c, n, s in settings:
            for i in range(n):
                # t == 1 is the first block, which skips the expansion conv.
                layers.append(inverted_res_block(in_ch, c, s if i == 0 else 1, t, t == 1))
                in_ch = c
        layers += [
            nn.Conv2d(in_ch, 1280, kernel_size=1, stride=1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.model = nn.Sequential(*layers)
        self.fc = nn.Linear(1280, num_classes)

    def forward(self, x):
        x = self.model(x)
        # (N, 1280, 1, 1) -> (N, 1280)
        x = x.view(x.size(0), x.size(1))
        x = self.fc(x)
        return x

MobileNetv3(Small)

class hard_swish(nn.Module):
    """Hard-swish activation: ``x * ReLU6(x + 3) / 6``.

    Args:
        inplace: forwarded to the internal ReLU6. That ReLU6 only acts on the
            temporary ``x + 3`` tensor, so the caller's input is not modified.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3.
        clipped = self.relu(shifted)
        return x * clipped / 6.
class squeeze(nn.Module):
    """Squeeze-and-excitation module for MobileNetV3 bottlenecks.

    Globally average-pools each channel, passes the channel vector through a
    two-layer MLP (reduction to ``up_dim // 4``), converts the result into
    per-channel weights in [0, 1] with a hard-sigmoid, and rescales the input
    channels by those weights.

    Args:
        up_dim: number of channels of the input feature map.
    """

    def __init__(self, up_dim):
        super(squeeze, self).__init__()
        self.avg = nn.AdaptiveAvgPool2d(1)
        # Parameter layout (model.0 / model.2 Linear layers) is unchanged,
        # so existing state_dicts still load.
        self.model = nn.Sequential(
            nn.Linear(up_dim, up_dim // 4),
            nn.ReLU6(inplace=True),
            nn.Linear(up_dim // 4, up_dim),
        )

    def forward(self, x):
        n, c = x.size(0), x.size(1)
        w = self.avg(x).view(n, c)
        w = self.model(w)
        # Hard-sigmoid gate ReLU6(w + 3) / 6, bounded to [0, 1] as MobileNetV3
        # specifies. Hard-swish is not a gate: it is unbounded above and can be
        # negative, so it could amplify or sign-flip channels.
        w = nn.functional.relu6(w + 3.) / 6.
        return torch.mul(x, w.view(n, c, 1, 1))
class bottleneck(nn.Module):
    """MobileNetV3 bottleneck block.

    1x1 expansion (conv -> BN -> act), kxk depthwise (conv -> BN -> act),
    optional squeeze-and-excitation, then 1x1 linear projection (conv -> BN)
    with an identity shortcut when the block is shape-preserving.

    Args:
        input_channel: channels of the input.
        output: channels after the projection conv.
        kernel_size: depthwise kernel size; padding ``(k - 1) // 2`` keeps the
            spatial size at stride 1 for odd k.
        stride: depthwise stride.
        up_dim: expanded (hidden) channel count.
        sq: if truthy, insert a ``squeeze`` (SE) module after the depthwise stage.
        activation_fun: ``'HS'`` selects hard-swish; any other value selects ReLU6.
    """

    def __init__(self, input_channel, output, kernel_size, stride, up_dim, sq, activation_fun):
        super(bottleneck, self).__init__()
        use_hs = activation_fun == 'HS'
        self.conv1 = nn.Conv2d(input_channel, up_dim, kernel_size=1, stride=1)
        self.batch1 = nn.BatchNorm2d(up_dim)
        self.act_fun1 = hard_swish(inplace=True) if use_hs else nn.ReLU6(inplace=True)
        self.depth_wise = nn.Conv2d(up_dim, up_dim, kernel_size=kernel_size, stride=stride,
                                    padding=(kernel_size - 1) // 2, groups=up_dim)
        self.batch2 = nn.BatchNorm2d(up_dim)
        self.act_fun2 = hard_swish(inplace=True) if use_hs else nn.ReLU6(inplace=True)
        # Empty Sequential acts as the identity when SE is disabled.
        self.squeeze = squeeze(up_dim) if sq else nn.Sequential()
        self.conv2 = nn.Conv2d(up_dim, output, kernel_size=1, stride=1)
        self.batch3 = nn.BatchNorm2d(output)
        # Shortcut only for stride-1, same-width blocks. A runtime shape check
        # would also add it to stride-2 blocks whenever the feature map is 1x1.
        self.use_residual = stride == 1 and input_channel == output

    def forward(self, x):
        out = self.act_fun1(self.batch1(self.conv1(x)))
        out = self.act_fun2(self.batch2(self.depth_wise(out)))
        out = self.squeeze(out)
        # Linear projection: no activation after batch3.
        out = self.batch3(self.conv2(out))
        if self.use_residual:
            out = out + x
        return out
class MobileNetv3_small(nn.Module):
    """MobileNetV3-Small classifier.

    Stem conv (3 -> 16, stride 2, hard-swish), eleven ``bottleneck`` blocks,
    a 1x1 conv to 576 channels, global average pooling, then a two-layer
    head (576 -> 1024 with hard-swish -> num_classes).

    Args:
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Columns: input_channel, output, kernel_size, stride, up_dim, sq, activation_fun.
        # Shape comments assume a 3x224x224 input.
        block_specs = [
            (16, 16, 3, 2, 16, True, 'RE'),     # 16x112x112 -> 16x56x56
            (16, 24, 3, 2, 72, False, 'RE'),    # -> 24x28x28
            (24, 24, 3, 1, 88, False, 'RE'),
            (24, 40, 5, 2, 96, True, 'HS'),     # -> 40x14x14
            (40, 40, 5, 1, 240, True, 'HS'),
            (40, 40, 5, 1, 240, True, 'HS'),
            (40, 48, 5, 1, 120, True, 'HS'),    # -> 48x14x14
            (48, 48, 5, 1, 144, True, 'HS'),
            (48, 96, 5, 2, 288, True, 'HS'),    # -> 96x7x7
            (96, 96, 5, 1, 576, True, 'HS'),
            (96, 96, 5, 1, 576, True, 'HS'),
        ]
        stem = [
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            hard_swish(inplace=True),
        ]
        blocks = [bottleneck(*spec) for spec in block_specs]
        head = [
            nn.Conv2d(96, 576, kernel_size=1, stride=1),
            nn.BatchNorm2d(576),
            hard_swish(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.model = nn.Sequential(*stem, *blocks, *head)
        self.linear1 = nn.Linear(576, 1024)
        self.act_fun1 = hard_swish(inplace=True)
        self.linear2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # (N, 576, 1, 1) -> (N, 576)
        features = self.model(x).flatten(1)
        hidden = self.act_fun1(self.linear1(features))
        return self.linear2(hidden)

至此,MobileNet 系列网络的 PyTorch 实现已全部给出(MobileNetv3 仅给出 Small 版本)。
另外,如果读者想查看网络结构的细节,可以使用下面的代码;稍作修改,即可直接用于其他文章中的网络。

# Smoke test: build a network, run one forward pass, and print its structure.
# Swap MobileNetv1 for MobileNetv2 / MobileNetv3_small (or another network) as needed.
# Falls back to CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = MobileNetv1(10).to(device)
# A batch of 10 random 3x224x224 images; it must live on the same device as the model.
inputs = torch.randn(10, 3, 224, 224, device=device)
out = net(inputs)
# Network structure
print(net)
# Output shape: (batch, num_classes)
print(out.shape)
# Per-layer details. NOTE(review): `summary` is not imported here; presumably
# torchsummary's `summary` (pip install torchsummary) — confirm before running.
summary(net, (3, 224, 224), device=device.type)

文章作者: Fanrencli
版权声明: 本博客所有文章除特别声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Fanrencli !
  目录