mirror of
https://github.com/EECS-467-W20-RRRobot-Project/RRRobot.git
synced 2025-08-13 21:48:32 +00:00
Add files via upload
This commit is contained in:
79
cv_model_architectures.py
Normal file
79
cv_model_architectures.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import torch, torchvision
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Model_resnet(nn.Module):
|
||||
def __init__(self, freeze=False, pretrained=True):
|
||||
super(Model_resnet, self).__init__()
|
||||
if pretrained:
|
||||
self.resnet = torchvision.models.resnet34(pretrained=True)
|
||||
else:
|
||||
self.resnet = torchvision.models.resnet34(pretrained=False)
|
||||
if freeze:
|
||||
for param in self.resnet.parameters():
|
||||
param.requires_grad = False
|
||||
self.resnet.fc = torch.nn.Linear(512, 1024)
|
||||
self.addition = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(1024),
|
||||
nn.Dropout(p=0.5),
|
||||
nn.Linear(1024, 512),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(512),
|
||||
nn.Dropout(p=0.5),
|
||||
nn.Linear(512, 6))
|
||||
|
||||
def forward(self, image):
|
||||
z = self.resnet(image)
|
||||
z = self.addition(z)
|
||||
return z
|
||||
|
||||
|
||||
|
||||
class Model_inception(nn.Module):
|
||||
def __init__(self, freeze=False, pretrained=True):
|
||||
super(Model_inception, self).__init__()
|
||||
if pretrained:
|
||||
self.inceptionnet = torchvision.models.inception_v3(pretrained=True)
|
||||
else:
|
||||
self.inceptionnet = torchvision.models.inception_v3(pretrained=False)
|
||||
if freeze:
|
||||
for param in self.inceptionnet.parameters():
|
||||
param.requires_grad = False
|
||||
self.inceptionnet.aux_logits = False
|
||||
self.inceptionnet.fc = torch.nn.Linear(2048, 1024)
|
||||
self.addition = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(1024),
|
||||
nn.Dropout(p=0.5),
|
||||
nn.Linear(1024, 512),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm1d(512),
|
||||
nn.Dropout(p=0.5),
|
||||
nn.Linear(512, 6))
|
||||
|
||||
def forward(self, image):
|
||||
z = self.inceptionnet(image)
|
||||
z = self.addition(z)
|
||||
return z
|
||||
|
||||
|
||||
class Model_mobilenet(nn.Module):
|
||||
def __init__(self, freeze=False, pretrained=True):
|
||||
super(Model_mobilenet, self).__init__()
|
||||
if pretrained:
|
||||
self.mobilenet = torchvision.models.mobilenet_v2(pretrained=True)
|
||||
else:
|
||||
self.mobilenet = torchvision.models.mobilenet_v2(pretrained=False)
|
||||
if freeze:
|
||||
for param in self.mobilenet.parameters():
|
||||
param.requires_grad = False
|
||||
self.mobilenet.classifier = nn.Sequential(
|
||||
nn.Dropout(p=0.2, inplace=False),
|
||||
nn.Linear(1280, 6, bias=True))
|
||||
|
||||
def forward(self, image):
|
||||
z = self.mobilenet(image)
|
||||
return z
|
Reference in New Issue
Block a user