文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 9b02e4db by 王肇一

mrnet

parent ca3e4a87
...@@ -42,7 +42,7 @@ def step_1(net, args, device, list, position): ...@@ -42,7 +42,7 @@ def step_1(net, args, device, list, position):
cv.imwrite('data/masks/' + fn[0] + '/' + fn[1], result) cv.imwrite('data/masks/' + fn[0] + '/' + fn[1], result)
def step_2(list, position): def step_2(list, position=1):
for num, dir in enumerate(list): for num, dir in enumerate(list):
#df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)')) #df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)'))
values = [] values = []
...@@ -52,20 +52,21 @@ def step_2(list, position): ...@@ -52,20 +52,21 @@ def step_2(list, position):
img = cv.imread('data/imgs/' + dir + '/' + name, 0) img = cv.imread('data/imgs/' + dir + '/' + name, 0)
mask = cv.imread('data/masks/' + dir + '/' + name, 0) mask = cv.imread('data/masks/' + dir + '/' + name, 0)
value = get_subarea_info(img, mask) value, count = get_subarea_info(img, mask)
ug = 0.0 ug = 0.0
if str.lower(match_group.group(1)).endswith('ug'): if str.lower(match_group.group(1)).endswith('ug'):
ug = float(str.lower(match_group.group(1))[:-2]) ug = float(str.lower(match_group.group(1))[:-2])
elif str.lower(match_group.group(1))== 'd2o':
ug = 0
elif str.lower(match_group.group(1)) == 'd2o': elif str.lower(match_group.group(1)) == 'd2o':
ug = 0
elif str.lower(match_group.group(1)) == 'lb':
ug = -1 ug = -1
iter = str.lower(match_group.group(2)) values.append({'Intensity (a. u.)': value, 'ug': ug, 'count': count})
values.append({'Intensity (a. u.)':value,'ug':ug,'iter':iter})
df = pd.DataFrame(values) df = pd.DataFrame(values)
df.sort_values('ug', inplace = True) df.sort_values('ug', inplace = True)
baseline = df[df['ug'] == 0]['Intensity (a. u.)'].mean()*0.62 df.replace(-1, 'lb', inplace = True)
df.replace(0, 'd2o', inplace = True)
baseline = df[df['ug'] == 'd2o']['Intensity (a. u.)'].mean()*0.62
sns.set_style("darkgrid") sns.set_style("darkgrid")
sns.catplot(x = 'ug', y = 'Intensity (a. u.)', kind = 'bar', palette = 'vlag', data = df) sns.catplot(x = 'ug', y = 'Intensity (a. u.)', kind = 'bar', palette = 'vlag', data = df)
#sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0) #sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0)
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
\ No newline at end of file from .train import train_net
from .mrnet_module import MultiUnet
\ No newline at end of file
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import torch import torch
import torch.nn.functional as F
from torch import nn from torch import nn
from .mrnet_parts import MultiResBlock,ResPath,TransCompose
class MultiUnet(nn.Module): class MultiUnet(nn.Module):
...@@ -14,80 +14,60 @@ class MultiUnet(nn.Module): ...@@ -14,80 +14,60 @@ class MultiUnet(nn.Module):
self.bilinear = bilinear self.bilinear = bilinear
self.inconv = nn.Sequential( self.inconv = nn.Sequential(
nn.Conv2d(n_channels, 16, kernel_size = 3, stride = 1, padding_mode = 'same'), nn.Conv2d(n_channels, 8, kernel_size = 3, stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(8), nn.BatchNorm2d(8),
nn.ReLU(inplace = True) nn.ReLU(inplace = True)
) )
self.res1 = MultiResBlock(16, 32) self.res1 = MultiResBlock(8, 32)
self.res2 = MultiResBlock(self.res1.outc, 32*2) self.res2 = MultiResBlock(self.res1.outc, 32*2)
self.res3 = MultiResBlock(self.res2.outc, 32*4) self.res3 = MultiResBlock(self.res2.outc, 32*4)
self.res4 = MultiResBlock(self.res3.outc, 32*8) self.res4 = MultiResBlock(self.res3.outc, 32*8)
self.res5 = MultiResBlock(self.res4.outc, 32*16) self.res5 = MultiResBlock(self.res4.outc, 32*16)
self.res6 = MultiResBlock(self.res5.outc, 32*8) self.path1_9 = ResPath(self.res1.outc, 32, 4)
self.res7 = MultiResBlock(self.res6.outc, 32*4) self.path2_8 = ResPath(self.res2.outc, 32*2, 3)
self.res8 = MultiResBlock(self.res7.outc, 32*2) self.path3_7 = ResPath(self.res3.outc, 32*4, 2)
self.res9 = MultiResBlock(self.res8.outc, 32) self.path4_6 = ResPath(self.res4.outc, 32*8, 1)
self.up6 = TransCompose(self.res5.outc, 32 * 8)
self.res6 = MultiResBlock(self.up6.outc*2, 32 * 8)
self.up7 = TransCompose(self.res6.outc, 32 * 4)
self.res7 = MultiResBlock(self.up7.outc*2, 32 * 4)
self.up8 = TransCompose(self.res7.outc, 32 * 2)
self.res8 = MultiResBlock(self.up8.outc*2, 32 * 2)
self.up9 = TransCompose(self.res8.outc, 32)
self.res9 = MultiResBlock(self.up9.outc*2, 32)
self.pool = nn.MaxPool2d(2)
self.outconv = nn.Conv2d(self.res9.outc, n_classes, kernel_size = 3, padding_mode = 'same')
def forward(self, x): def forward(self, x):
pass x = self.inconv(x)
res1 = self.res1(x)
pool1 = self.pool(res1)
res2 = self.res2(pool1)
pool2 = self.pool(res2)
class MultiResBlock(nn.Module): res3 = self.res3(pool2)
def __init__(self, in_channels, U, alpha = 1.67): pool3 = self.pool(res3)
super().__init__()
self.U = U
self.W = U*alpha
self.inc = in_channels
self.outc = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5)
def conv(self, in_channel, out_channel, kernel_size): res4 = self.res4(pool3)
return nn.Sequential( pool4 = self.pool(res4)
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace = True)
)
def forward(self, x): res5 = self.res5(pool4)
shortcut = nn.Sequential(
nn.Conv2d(in_channels = self.inc,out_channels = self.outc,kernel_size = 1,stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(self.outc))(x)
conv3 = self.conv(self.inc, int(self.W * 0.167), 3)(x)
conv5 = self.conv(int(self.W * 0.167), int(self.W * 0.333), 3)(conv3)
conv7 = self.conv(int(self.W * 0.333), int(self.W * 0.5), 3)(conv5)
result = torch.cat([conv3, conv5, conv7], 2)
result = nn.BatchNorm2d(self.outc)(result)
result = torch.add(result, shortcut)
result = nn.Sequential(
nn.ReLU(inplace = True),
nn.BatchNorm2d(self.outc))(result)
return result
class ResPath(nn.Module):
def __init__(self, in_channels, out_channels, length):
super().__init__()
self.length = length
self.inc = in_channels
self.outc = out_channels
def unit(self, in_channels, out_channels, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1,padding_mode = 'same'),
nn.BatchNorm2d(self.outc))(x)
conv = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1,padding_mode = 'same'),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True))(x)
result = torch.add(conv, shortcut)
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
def forward(self, x): res1 = self.path1_9(res1)
x = self.unit(self.inc,self.outc, x) res2 = self.path2_8(res2)
for i in range(self.length -1): res3 = self.path3_7(res3)
x = self.unit(self.outc,self.outc,x) res4 = self.path4_6(res4)
return x
res6 = self.res6(torch.cat([res4, self.up6(res5)], 1))
res7 = self.res7(torch.cat([res3, self.up7(res6)], 1))
res8 = self.res8(torch.cat([res2, self.up8(res7)], 1))
res9 = self.res9(torch.cat([res1, self.up9(res8)], 1))
out = self.outconv(res9)
return out
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
\ No newline at end of file import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
class MultiResBlock(nn.Module):
def __init__(self, in_channels, U, alpha = 1.67):
super().__init__()
self.U = U
self.W = U * alpha
self.inc = in_channels
self.outc = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5)
def conv(self, in_channel, out_channel, kernel_size):
return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1,padding = 1,
padding_mode = 'same'), nn.BatchNorm2d(out_channel), nn.ReLU(inplace = True))
def forward(self, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.outc))(x)
conv3 = self.conv(self.inc, int(self.W * 0.167), 3)(x)
conv5 = self.conv(int(self.W * 0.167), int(self.W * 0.333), 3)(conv3)
conv7 = self.conv(int(self.W * 0.333), int(self.W * 0.5), 3)(conv5)
result = torch.cat([conv3, conv5, conv7], 1)
result = nn.BatchNorm2d(self.outc)(result)
result = torch.add(result, shortcut)
result = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(self.outc))(result)
return result
class TransCompose(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.inc = in_channels
self.outc = out_channels
def forward(self,x):
return nn.Sequential(
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 3, stride = 2),
nn.BatchNorm2d(self.outc)
)(x)
class ResPath(nn.Module):
def __init__(self, in_channels, out_channels, length):
super().__init__()
self.length = length
self.inc = in_channels
self.outc = out_channels
def unit(self, in_channels, out_channels, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.outc))(x)
conv = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True))(x)
result = torch.add(conv, shortcut)
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
def forward(self, x):
x = self.unit(self.inc, self.outc, x)
for i in range(self.length - 1):
x = self.unit(self.outc, self.outc, x)
return x
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
\ No newline at end of file
import os
import logging
from tqdm import tqdm
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from utils.dataset import BasicDataset
from utils.eval import eval_net
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCELoss()
for epoch in tqdm(range(epochs)):
net.train()
epoch_loss = 0
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
val_score = eval_net(net, val_loader, device, n_val)
logging.info('Validation cross entropy: {}'.format(val_score))
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
torch.save(net.state_dict(), 'MODEL.pth')
...@@ -65,21 +65,18 @@ def get_subarea_info(img, mask): ...@@ -65,21 +65,18 @@ def get_subarea_info(img, mask):
back_value = img[np.where(real_bg_mask != 0)] back_value = img[np.where(real_bg_mask != 0)]
back_mean = np.mean(back_value) back_mean = np.mean(back_value)
info.append({'mean': area_mean,'back':back_mean,'size':area_size})
info.append({'id': i, 'size': area_size, 'area_mean': area_mean, 'back_mean': back_mean})
# endif # endif
df = pd.DataFrame(info) df = pd.DataFrame(info)
df['Intensity (a. u.)'] = df['area_mean'] - df['back_mean']
median = np.median(df['Intensity (a. u.)']) median = np.median(df['mean'])
b = 1.4826 b = 1.4826
mad = b * np.median(np.abs(df['Intensity (a. u.)'] - median)) mad = b * np.median(np.abs(df['mean'] - median))
lower_limit = median - (3 * mad) lower_limit = median - (3 * mad)
upper_limit = median + (3 * mad) upper_limit = median + (3 * mad)
#df = df[df['Intensity (a. u.)'] > lower_limit] df = df[df['mean'] >= lower_limit]
#df = df[df['Intensity (a. u.)'] < upper_limit] df = df[df['mean'] <= upper_limit]
df['value'] = df['mean']-df['back']
value = df['Intensity (a. u.)'].mean()
return value return (df['value'] * df['size']).sum() / df['size'].sum(),df.shape[0]
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
\ No newline at end of file from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvision
def conv3x3(in_, out):
return nn.Conv2d(in_, out, 3, padding = 1)
class ConvRelu(nn.Module):
def __init__(self, in_, out):
super().__init__()
self.conv = conv3x3(in_, out)
self.activation = nn.ReLU(inplace = True)
def forward(self, x):
x = self.conv(x)
x = self.activation(x)
return x
class DecoderBlock(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.block = nn.Sequential(ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size = 3, stride = 2, padding = 1,
output_padding = 1), nn.ReLU(inplace = True))
def forward(self, x):
return self.block(x)
class Interpolate(nn.Module):
def __init__(self, size = None, scale_factor = None, mode = 'nearest', align_corners = False):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.size = size
self.mode = mode
self.scale_factor = scale_factor
self.align_corners = align_corners
def forward(self, x):
x = self.interp(x, size = self.size, scale_factor = self.scale_factor, mode = self.mode,
align_corners = self.align_corners)
return x
class DecoderBlockV2(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, is_deconv = True):
super(DecoderBlockV2, self).__init__()
self.in_channels = in_channels
if is_deconv:
"""
Paramaters for Deconvolution were chosen to avoid artifacts, following
link https://distill.pub/2016/deconv-checkerboard/
"""
self.block = nn.Sequential(ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size = 4, stride = 2, padding = 1),
nn.ReLU(inplace = True))
else:
self.block = nn.Sequential(Interpolate(scale_factor = 2, mode = 'bilinear'),
ConvRelu(in_channels, middle_channels), ConvRelu(middle_channels, out_channels), )
def forward(self, x):
return self.block(x)
class UNet11(nn.Module):
def __init__(self, num_filters = 32, pretrained = False):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network is used
True - encoder is pre-trained with VGG11
"""
super().__init__()
self.pool = nn.MaxPool2d(2, 2)
self.encoder = models.vgg11(pretrained = pretrained).features
self.relu = self.encoder[1]
self.conv1 = self.encoder[0]
self.conv2 = self.encoder[3]
self.conv3s = self.encoder[6]
self.conv3 = self.encoder[8]
self.conv4s = self.encoder[11]
self.conv4 = self.encoder[13]
self.conv5s = self.encoder[16]
self.conv5 = self.encoder[18]
self.center = DecoderBlock(num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8)
self.dec5 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8)
self.dec4 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4)
self.dec3 = DecoderBlock(num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2)
self.dec2 = DecoderBlock(num_filters * (4 + 2), num_filters * 2 * 2, num_filters)
self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)
self.final = nn.Conv2d(num_filters, 1, kernel_size = 1)
def forward(self, x):
conv1 = self.relu(self.conv1(x))
conv2 = self.relu(self.conv2(self.pool(conv1)))
conv3s = self.relu(self.conv3s(self.pool(conv2)))
conv3 = self.relu(self.conv3(conv3s))
conv4s = self.relu(self.conv4s(self.pool(conv3)))
conv4 = self.relu(self.conv4(conv4s))
conv5s = self.relu(self.conv5s(self.pool(conv4)))
conv5 = self.relu(self.conv5(conv5s))
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(torch.cat([dec2, conv1], 1))
return self.final(dec1)
class UNet16(nn.Module):
def __init__(self, num_classes = 1, num_filters = 32, pretrained = False, is_deconv = False):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network used
True - encoder pre-trained with VGG16
:is_deconv:
False: bilinear interpolation is used in decoder
True: deconvolution is used in decoder
"""
super().__init__()
self.n_classes = num_classes
self.pool = nn.MaxPool2d(2, 2)
self.encoder = torchvision.models.vgg16(pretrained = pretrained).features
self.relu = nn.ReLU(inplace = True)
self.conv1 = nn.Sequential(self.encoder[0], self.relu, self.encoder[2], self.relu)
self.conv2 = nn.Sequential(self.encoder[5], self.relu, self.encoder[7], self.relu)
self.conv3 = nn.Sequential(self.encoder[10], self.relu, self.encoder[12], self.relu, self.encoder[14],
self.relu)
self.conv4 = nn.Sequential(self.encoder[17], self.relu, self.encoder[19], self.relu, self.encoder[21],
self.relu)
self.conv5 = nn.Sequential(self.encoder[24], self.relu, self.encoder[26], self.relu, self.encoder[28],
self.relu)
self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8)
self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2)
self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters)
self.dec1 = ConvRelu(64 + num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size = 1)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(self.pool(conv1))
conv3 = self.conv3(self.pool(conv2))
conv4 = self.conv4(self.pool(conv3))
conv5 = self.conv5(self.pool(conv4))
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(torch.cat([dec2, conv1], 1))
if self.n_classes > 1:
x_out = F.log_softmax(self.final(dec1), dim = 1)
else:
x_out = self.final(dec1)
return x_out
...@@ -4,112 +4,21 @@ import os ...@@ -4,112 +4,21 @@ import os
import sys import sys
import torch import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from utils.eval import eval_net import unet
import mrnet
from mrnet import MultiUnet
from unet import UNet from unet import UNet
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/' dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/' dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/' dir_checkpoint = 'checkpoints/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
global_step = 0
logging.info(f'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
''')
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
#optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
criterion = nn.BCEWithLogitsLoss()
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
assert imgs.shape[1] == net.n_channels, \
f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.update(imgs.shape[0])
global_step += 1
if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score = eval_net(net, val_loader, device, n_val)
if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step)
else:
logging.info('Validation Dice Coeff: {}'.format(val_score))
writer.add_scalar('Dice/test', val_score, global_step)
writer.add_images('images', imgs, global_step)
if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
writer.close()
def get_args(): def get_args():
parser = argparse.ArgumentParser(description = 'Train the UNet on images and target masks', parser = argparse.ArgumentParser(description = 'Train the UNet on images and target masks',
formatter_class = argparse.ArgumentDefaultsHelpFormatter) formatter_class = argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-e', '--epochs', metavar = 'E', type = int, default = 5, help = 'Number of epochs', parser.add_argument('-e', '--epochs', metavar = 'E', type = int, default = 1, help = 'Number of epochs',
dest = 'epochs') dest = 'epochs')
parser.add_argument('-b', '--batch-size', metavar = 'B', type = int, nargs = '?', default = 1, help = 'Batch size', parser.add_argument('-b', '--batch-size', metavar = 'B', type = int, nargs = '?', default = 1, help = 'Batch size',
dest = 'batchsize') dest = 'batchsize')
...@@ -117,7 +26,7 @@ def get_args(): ...@@ -117,7 +26,7 @@ def get_args():
help = 'Learning rate', dest = 'lr') help = 'Learning rate', dest = 'lr')
parser.add_argument('-f', '--load', dest = 'load', type = str, default = False, parser.add_argument('-f', '--load', dest = 'load', type = str, default = False,
help = 'Load model from a .pth file') help = 'Load model from a .pth file')
parser.add_argument('-s', '--scale', dest = 'scale', type = float, default = 0.5, parser.add_argument('-s', '--scale', dest = 'scale', type = float, default = 2.56,
help = 'Downscaling factor of the images') help = 'Downscaling factor of the images')
parser.add_argument('-v', '--validation', dest = 'val', type = float, default = 10.0, parser.add_argument('-v', '--validation', dest = 'val', type = float, default = 10.0,
help = 'Percent of the data that is used as validation (0-100)') help = 'Percent of the data that is used as validation (0-100)')
...@@ -137,7 +46,10 @@ if __name__ == '__main__': ...@@ -137,7 +46,10 @@ if __name__ == '__main__':
# - For 1 class and background, use n_classes=1 # - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1 # - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N # - For N > 2 classes, use n_classes=N
net = UNet(n_channels = 1, n_classes = 1)
#net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n' logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n' f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n' f'\t{net.n_classes} output channels (classes)\n'
...@@ -152,7 +64,7 @@ if __name__ == '__main__': ...@@ -152,7 +64,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True # cudnn.benchmark = True
try: try:
train_net(net = net, device=device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, mrnet.train_net(net = net, device=device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
img_scale = args.scale, val_percent = args.val / 100) img_scale = args.scale, val_percent = args.val / 100)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth') torch.save(net.state_dict(), 'INTERRUPTED.pth')
......
from .unet_model import UNet from .unet_model import UNet
from .train import train_net
\ No newline at end of file
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
\ No newline at end of file import argparse
import logging
import os
import sys
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from utils.eval import eval_net
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
global_step = 0
logging.info(f'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
''')
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
# criterion = nn.BCEWithLogitsLoss()
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
criterion = nn.BCEWithLogitsLoss()
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
# assert imgs.shape[1] == net.n_channels, \
# f'Network has been defined with {net.n_channels} input channels, ' \
# f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
# 'the images are loaded correctly.'
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.update(imgs.shape[0])
global_step += 1 # if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score = eval_net(net, val_loader, device, n_val)
if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step)
else:
logging.info('Validation Dice Coeff: {}'.format(val_score))
writer.add_scalar('Dice/test', val_score, global_step)
writer.add_images('images', imgs, global_step)
if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp and epoch % 4 == 0:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
torch.save(net.state_dict(), 'MODEL.pth')
writer.close()
\ No newline at end of file
...@@ -13,7 +13,7 @@ class BasicDataset(Dataset): ...@@ -13,7 +13,7 @@ class BasicDataset(Dataset):
self.imgs_dir = imgs_dir self.imgs_dir = imgs_dir
self.masks_dir = masks_dir self.masks_dir = masks_dir
self.scale = scale self.scale = scale
assert 0 < scale <= 1, 'Scale must be between 0 and 1' #assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.ids = [splitext(file)[0] for file in listdir(imgs_dir) self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')] if not file.startswith('.')]
...@@ -25,7 +25,7 @@ class BasicDataset(Dataset): ...@@ -25,7 +25,7 @@ class BasicDataset(Dataset):
@classmethod @classmethod
def preprocess(cls, pil_img, scale): def preprocess(cls, pil_img, scale):
w, h = pil_img.size w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h) newW, newH = 256,256 #int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small' assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((newW, newH)) pil_img = pil_img.resize((newW, newH))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment