Add files via upload

This commit is contained in:
chenxgu
2020-04-27 23:01:44 -04:00
committed by GitHub
parent e27990cb75
commit fd82f35698

79
cv_model_architectures.py Normal file
View File

@@ -0,0 +1,79 @@
import torch, torchvision
import numpy as np
import torch.nn.functional as F
from torch import nn
class Model_resnet(nn.Module):
def __init__(self, freeze=False, pretrained=True):
super(Model_resnet, self).__init__()
if pretrained:
self.resnet = torchvision.models.resnet34(pretrained=True)
else:
self.resnet = torchvision.models.resnet34(pretrained=False)
if freeze:
for param in self.resnet.parameters():
param.requires_grad = False
self.resnet.fc = torch.nn.Linear(512, 1024)
self.addition = nn.Sequential(
nn.ReLU(),
nn.BatchNorm1d(1024),
nn.Dropout(p=0.5),
nn.Linear(1024, 512),
nn.ReLU(),
nn.BatchNorm1d(512),
nn.Dropout(p=0.5),
nn.Linear(512, 6))
def forward(self, image):
z = self.resnet(image)
z = self.addition(z)
return z
class Model_inception(nn.Module):
def __init__(self, freeze=False, pretrained=True):
super(Model_inception, self).__init__()
if pretrained:
self.inceptionnet = torchvision.models.inception_v3(pretrained=True)
else:
self.inceptionnet = torchvision.models.inception_v3(pretrained=False)
if freeze:
for param in self.inceptionnet.parameters():
param.requires_grad = False
self.inceptionnet.aux_logits = False
self.inceptionnet.fc = torch.nn.Linear(2048, 1024)
self.addition = nn.Sequential(
nn.ReLU(),
nn.BatchNorm1d(1024),
nn.Dropout(p=0.5),
nn.Linear(1024, 512),
nn.ReLU(),
nn.BatchNorm1d(512),
nn.Dropout(p=0.5),
nn.Linear(512, 6))
def forward(self, image):
z = self.inceptionnet(image)
z = self.addition(z)
return z
class Model_mobilenet(nn.Module):
def __init__(self, freeze=False, pretrained=True):
super(Model_mobilenet, self).__init__()
if pretrained:
self.mobilenet = torchvision.models.mobilenet_v2(pretrained=True)
else:
self.mobilenet = torchvision.models.mobilenet_v2(pretrained=False)
if freeze:
for param in self.mobilenet.parameters():
param.requires_grad = False
self.mobilenet.classifier = nn.Sequential(
nn.Dropout(p=0.2, inplace=False),
nn.Linear(1280, 6, bias=True))
def forward(self, image):
z = self.mobilenet(image)
return z