Commit ca3e4a87 by 王肇一

mrnet

parent 77ddba8d
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn.functional as F
from torch import nn
class MultiUnet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear = True):
super(MultiUnet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inconv = nn.Sequential(
nn.Conv2d(n_channels, 16, kernel_size = 3, stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(8),
nn.ReLU(inplace = True)
)
self.res1 = MultiResBlock(16, 32)
self.res2 = MultiResBlock(self.res1.outc, 32*2)
self.res3 = MultiResBlock(self.res2.outc, 32*4)
self.res4 = MultiResBlock(self.res3.outc, 32*8)
self.res5 = MultiResBlock(self.res4.outc, 32*16)
self.res6 = MultiResBlock(self.res5.outc, 32*8)
self.res7 = MultiResBlock(self.res6.outc, 32*4)
self.res8 = MultiResBlock(self.res7.outc, 32*2)
self.res9 = MultiResBlock(self.res8.outc, 32)
def forward(self, x):
pass
class MultiResBlock(nn.Module):
def __init__(self, in_channels, U, alpha = 1.67):
super().__init__()
self.U = U
self.W = U*alpha
self.inc = in_channels
self.outc = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5)
def conv(self, in_channel, out_channel, kernel_size):
return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace = True)
)
def forward(self, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = self.inc,out_channels = self.outc,kernel_size = 1,stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(self.outc))(x)
conv3 = self.conv(self.inc, int(self.W * 0.167), 3)(x)
conv5 = self.conv(int(self.W * 0.167), int(self.W * 0.333), 3)(conv3)
conv7 = self.conv(int(self.W * 0.333), int(self.W * 0.5), 3)(conv5)
result = torch.cat([conv3, conv5, conv7], 2)
result = nn.BatchNorm2d(self.outc)(result)
result = torch.add(result, shortcut)
result = nn.Sequential(
nn.ReLU(inplace = True),
nn.BatchNorm2d(self.outc))(result)
return result
class ResPath(nn.Module):
def __init__(self, in_channels, out_channels, length):
super().__init__()
self.length = length
self.inc = in_channels
self.outc = out_channels
def unit(self, in_channels, out_channels, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1,padding_mode = 'same'),
nn.BatchNorm2d(self.outc))(x)
conv = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1,padding_mode = 'same'),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True))(x)
result = torch.add(conv, shortcut)
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
def forward(self, x):
x = self.unit(self.inc,self.outc, x)
for i in range(self.length -1):
x = self.unit(self.outc,self.outc,x)
return x
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
from torch.nn import Module, Sequential, Conv2d, ReLU, AdaptiveMaxPool2d, AdaptiveAvgPool2d, NLLLoss, BCELoss, \
CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding
from torch.nn import functional as F
from torch import nn
__all__ = ['PAM_Module', 'CAM_Module', 'semanticModule']
class _EncoderBlock(Module):
def __init__(self, in_channels, out_channels, dropout = False):
super(_EncoderBlock, self).__init__()
layers = [nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1), nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True), nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True), ]
if dropout:
layers.append(nn.Dropout())
layers.append(nn.MaxPool2d(kernel_size = 2, stride = 2))
self.encode = nn.Sequential(*layers)
def forward(self, x):
return self.encode(x)
class _DecoderBlock(Module):
def __init__(self, in_channels, middle_channels, out_channels):
super(_DecoderBlock, self).__init__()
self.decode = nn.Sequential(nn.Conv2d(in_channels, middle_channels, kernel_size = 3, padding = 1),
nn.BatchNorm2d(middle_channels), nn.ReLU(inplace = True),
nn.Conv2d(middle_channels, middle_channels, kernel_size = 3, padding = 1), nn.BatchNorm2d(middle_channels),
nn.ReLU(inplace = True), nn.ConvTranspose2d(middle_channels, out_channels, kernel_size = 2, stride = 2), )
def forward(self, x):
return self.decode(x)
class semanticModule(Module):
""" Semantic attention module"""
def __init__(self, in_dim):
super(semanticModule, self).__init__()
self.chanel_in = in_dim
self.enc1 = _EncoderBlock(in_dim, in_dim * 2)
self.enc2 = _EncoderBlock(in_dim * 2, in_dim * 4)
self.dec2 = _DecoderBlock(in_dim * 4, in_dim * 2, in_dim * 2)
self.dec1 = _DecoderBlock(in_dim * 2, in_dim, in_dim)
def forward(self, x):
enc1 = self.enc1(x)
enc2 = self.enc2(enc1)
dec2 = self.dec2(enc2)
dec1 = self.dec1(F.upsample(dec2, enc1.size()[2:], mode = 'bilinear'))
return enc2.view(-1), dec1
class PAM_Module(Module):
""" Position attention module"""
# Ref from SAGAN
def __init__(self, in_dim):
super(PAM_Module, self).__init__()
self.chanel_in = in_dim
self.query_conv = Conv2d(in_channels = in_dim, out_channels = in_dim // 8, kernel_size = 1)
self.key_conv = Conv2d(in_channels = in_dim, out_channels = in_dim // 8, kernel_size = 1)
self.value_conv = Conv2d(in_channels = in_dim, out_channels = in_dim, kernel_size = 1)
self.gamma = Parameter(torch.zeros(1))
self.softmax = Softmax(dim = -1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X (HxW) X (HxW)
"""
m_batchsize, C, height, width = x.size()
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
energy = torch.bmm(proj_query, proj_key)
attention = self.softmax(energy)
proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(m_batchsize, C, height, width)
out = self.gamma * out + x
return out
class CAM_Module(Module):
""" Channel attention module"""
def __init__(self, in_dim):
super(CAM_Module, self).__init__()
self.chanel_in = in_dim
self.gamma = Parameter(torch.zeros(1))
self.softmax = Softmax(dim = -1)
def forward(self, x):
"""
inputs :
x : input feature maps( B X C X H X W)
returns :
out : attention value + input feature
attention: B X C X C
"""
m_batchsize, C, height, width = x.size()
proj_query = x.view(m_batchsize, C, -1)
proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
energy = torch.bmm(proj_query, proj_key)
energy_new = torch.max(energy, -1, keepdim = True)[0].expand_as(energy) - energy
attention = self.softmax(energy_new)
proj_value = x.view(m_batchsize, C, -1)
out = torch.bmm(attention, proj_value)
out = out.view(m_batchsize, C, height, width)
out = self.gamma * out + x
return out
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch.nn.functional as F
from torch import nn
class Net(nn.Module):
def __init__(self, n_channels, n_classes, bilinear = True):
super(Net, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
def forward(self, x):
return logits
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment