文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 9b02e4db by 王肇一

mrnet

parent ca3e4a87
......@@ -42,7 +42,7 @@ def step_1(net, args, device, list, position):
cv.imwrite('data/masks/' + fn[0] + '/' + fn[1], result)
def step_2(list, position):
def step_2(list, position=1):
for num, dir in enumerate(list):
#df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)'))
values = []
......@@ -52,20 +52,21 @@ def step_2(list, position):
img = cv.imread('data/imgs/' + dir + '/' + name, 0)
mask = cv.imread('data/masks/' + dir + '/' + name, 0)
value = get_subarea_info(img, mask)
value, count = get_subarea_info(img, mask)
ug = 0.0
if str.lower(match_group.group(1)).endswith('ug'):
ug = float(str.lower(match_group.group(1))[:-2])
elif str.lower(match_group.group(1))== 'd2o':
ug = 0
elif str.lower(match_group.group(1)) == 'd2o':
ug = 0
elif str.lower(match_group.group(1)) == 'lb':
ug = -1
iter = str.lower(match_group.group(2))
values.append({'Intensity (a. u.)':value,'ug':ug,'iter':iter})
values.append({'Intensity (a. u.)': value, 'ug': ug, 'count': count})
df = pd.DataFrame(values)
df.sort_values('ug', inplace = True)
baseline = df[df['ug'] == 0]['Intensity (a. u.)'].mean()*0.62
df.replace(-1, 'lb', inplace = True)
df.replace(0, 'd2o', inplace = True)
baseline = df[df['ug'] == 'd2o']['Intensity (a. u.)'].mean()*0.62
sns.set_style("darkgrid")
sns.catplot(x = 'ug', y = 'Intensity (a. u.)', kind = 'bar', palette = 'vlag', data = df)
#sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0)
......
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
from .train import train_net
from .mrnet_module import MultiUnet
\ No newline at end of file
......@@ -2,8 +2,8 @@
# -*- coding:utf-8 -*-
import torch
import torch.nn.functional as F
from torch import nn
from .mrnet_parts import MultiResBlock,ResPath,TransCompose
class MultiUnet(nn.Module):
......@@ -14,80 +14,60 @@ class MultiUnet(nn.Module):
self.bilinear = bilinear
self.inconv = nn.Sequential(
nn.Conv2d(n_channels, 16, kernel_size = 3, stride = 1, padding_mode = 'same'),
nn.Conv2d(n_channels, 8, kernel_size = 3, stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(8),
nn.ReLU(inplace = True)
)
self.res1 = MultiResBlock(16, 32)
self.res1 = MultiResBlock(8, 32)
self.res2 = MultiResBlock(self.res1.outc, 32*2)
self.res3 = MultiResBlock(self.res2.outc, 32*4)
self.res4 = MultiResBlock(self.res3.outc, 32*8)
self.res5 = MultiResBlock(self.res4.outc, 32*16)
self.res6 = MultiResBlock(self.res5.outc, 32*8)
self.res7 = MultiResBlock(self.res6.outc, 32*4)
self.res8 = MultiResBlock(self.res7.outc, 32*2)
self.res9 = MultiResBlock(self.res8.outc, 32)
self.path1_9 = ResPath(self.res1.outc, 32, 4)
self.path2_8 = ResPath(self.res2.outc, 32*2, 3)
self.path3_7 = ResPath(self.res3.outc, 32*4, 2)
self.path4_6 = ResPath(self.res4.outc, 32*8, 1)
self.up6 = TransCompose(self.res5.outc, 32 * 8)
self.res6 = MultiResBlock(self.up6.outc*2, 32 * 8)
self.up7 = TransCompose(self.res6.outc, 32 * 4)
self.res7 = MultiResBlock(self.up7.outc*2, 32 * 4)
self.up8 = TransCompose(self.res7.outc, 32 * 2)
self.res8 = MultiResBlock(self.up8.outc*2, 32 * 2)
self.up9 = TransCompose(self.res8.outc, 32)
self.res9 = MultiResBlock(self.up9.outc*2, 32)
self.pool = nn.MaxPool2d(2)
self.outconv = nn.Conv2d(self.res9.outc, n_classes, kernel_size = 3, padding_mode = 'same')
def forward(self, x):
pass
x = self.inconv(x)
res1 = self.res1(x)
pool1 = self.pool(res1)
res2 = self.res2(pool1)
pool2 = self.pool(res2)
class MultiResBlock(nn.Module):
def __init__(self, in_channels, U, alpha = 1.67):
super().__init__()
self.U = U
self.W = U*alpha
self.inc = in_channels
self.outc = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5)
res3 = self.res3(pool2)
pool3 = self.pool(res3)
def conv(self, in_channel, out_channel, kernel_size):
return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace = True)
)
res4 = self.res4(pool3)
pool4 = self.pool(res4)
def forward(self, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = self.inc,out_channels = self.outc,kernel_size = 1,stride = 1, padding_mode = 'same'),
nn.BatchNorm2d(self.outc))(x)
conv3 = self.conv(self.inc, int(self.W * 0.167), 3)(x)
conv5 = self.conv(int(self.W * 0.167), int(self.W * 0.333), 3)(conv3)
conv7 = self.conv(int(self.W * 0.333), int(self.W * 0.5), 3)(conv5)
result = torch.cat([conv3, conv5, conv7], 2)
result = nn.BatchNorm2d(self.outc)(result)
result = torch.add(result, shortcut)
result = nn.Sequential(
nn.ReLU(inplace = True),
nn.BatchNorm2d(self.outc))(result)
return result
class ResPath(nn.Module):
def __init__(self, in_channels, out_channels, length):
super().__init__()
self.length = length
self.inc = in_channels
self.outc = out_channels
def unit(self, in_channels, out_channels, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1,padding_mode = 'same'),
nn.BatchNorm2d(self.outc))(x)
conv = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1,padding_mode = 'same'),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True))(x)
result = torch.add(conv, shortcut)
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
res5 = self.res5(pool4)
def forward(self, x):
x = self.unit(self.inc,self.outc, x)
for i in range(self.length -1):
x = self.unit(self.outc,self.outc,x)
return x
res1 = self.path1_9(res1)
res2 = self.path2_8(res2)
res3 = self.path3_7(res3)
res4 = self.path4_6(res4)
res6 = self.res6(torch.cat([res4, self.up6(res5)], 1))
res7 = self.res7(torch.cat([res3, self.up7(res6)], 1))
res8 = self.res8(torch.cat([res2, self.up8(res7)], 1))
res9 = self.res9(torch.cat([res1, self.up9(res8)], 1))
out = self.outconv(res9)
return out
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
class MultiResBlock(nn.Module):
def __init__(self, in_channels, U, alpha = 1.67):
super().__init__()
self.U = U
self.W = U * alpha
self.inc = in_channels
self.outc = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5)
def conv(self, in_channel, out_channel, kernel_size):
return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1,padding = 1,
padding_mode = 'same'), nn.BatchNorm2d(out_channel), nn.ReLU(inplace = True))
def forward(self, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.outc))(x)
conv3 = self.conv(self.inc, int(self.W * 0.167), 3)(x)
conv5 = self.conv(int(self.W * 0.167), int(self.W * 0.333), 3)(conv3)
conv7 = self.conv(int(self.W * 0.333), int(self.W * 0.5), 3)(conv5)
result = torch.cat([conv3, conv5, conv7], 1)
result = nn.BatchNorm2d(self.outc)(result)
result = torch.add(result, shortcut)
result = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(self.outc))(result)
return result
class TransCompose(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.inc = in_channels
self.outc = out_channels
def forward(self,x):
return nn.Sequential(
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 3, stride = 2),
nn.BatchNorm2d(self.outc)
)(x)
class ResPath(nn.Module):
def __init__(self, in_channels, out_channels, length):
super().__init__()
self.length = length
self.inc = in_channels
self.outc = out_channels
def unit(self, in_channels, out_channels, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.outc))(x)
conv = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace = True))(x)
result = torch.add(conv, shortcut)
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
def forward(self, x):
x = self.unit(self.inc, self.outc, x)
for i in range(self.length - 1):
x = self.unit(self.outc, self.outc, x)
return x
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
import os
import logging
from tqdm import tqdm
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from utils.dataset import BasicDataset
from utils.eval import eval_net
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCELoss()
for epoch in tqdm(range(epochs)):
net.train()
epoch_loss = 0
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
val_score = eval_net(net, val_loader, device, n_val)
logging.info('Validation cross entropy: {}'.format(val_score))
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
torch.save(net.state_dict(), 'MODEL.pth')
......@@ -65,21 +65,18 @@ def get_subarea_info(img, mask):
back_value = img[np.where(real_bg_mask != 0)]
back_mean = np.mean(back_value)
info.append({'id': i, 'size': area_size, 'area_mean': area_mean, 'back_mean': back_mean})
info.append({'mean': area_mean,'back':back_mean,'size':area_size})
# endif
df = pd.DataFrame(info)
df['Intensity (a. u.)'] = df['area_mean'] - df['back_mean']
median = np.median(df['Intensity (a. u.)'])
median = np.median(df['mean'])
b = 1.4826
mad = b * np.median(np.abs(df['Intensity (a. u.)'] - median))
mad = b * np.median(np.abs(df['mean'] - median))
lower_limit = median - (3 * mad)
upper_limit = median + (3 * mad)
#df = df[df['Intensity (a. u.)'] > lower_limit]
#df = df[df['Intensity (a. u.)'] < upper_limit]
value = df['Intensity (a. u.)'].mean()
df = df[df['mean'] >= lower_limit]
df = df[df['mean'] <= upper_limit]
df['value'] = df['mean']-df['back']
return value
return (df['value'] * df['size']).sum() / df['size'].sum(),df.shape[0]
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvision
def conv3x3(in_, out):
return nn.Conv2d(in_, out, 3, padding = 1)
class ConvRelu(nn.Module):
def __init__(self, in_, out):
super().__init__()
self.conv = conv3x3(in_, out)
self.activation = nn.ReLU(inplace = True)
def forward(self, x):
x = self.conv(x)
x = self.activation(x)
return x
class DecoderBlock(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.block = nn.Sequential(ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size = 3, stride = 2, padding = 1,
output_padding = 1), nn.ReLU(inplace = True))
def forward(self, x):
return self.block(x)
class Interpolate(nn.Module):
def __init__(self, size = None, scale_factor = None, mode = 'nearest', align_corners = False):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.size = size
self.mode = mode
self.scale_factor = scale_factor
self.align_corners = align_corners
def forward(self, x):
x = self.interp(x, size = self.size, scale_factor = self.scale_factor, mode = self.mode,
align_corners = self.align_corners)
return x
class DecoderBlockV2(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, is_deconv = True):
super(DecoderBlockV2, self).__init__()
self.in_channels = in_channels
if is_deconv:
"""
Paramaters for Deconvolution were chosen to avoid artifacts, following
link https://distill.pub/2016/deconv-checkerboard/
"""
self.block = nn.Sequential(ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size = 4, stride = 2, padding = 1),
nn.ReLU(inplace = True))
else:
self.block = nn.Sequential(Interpolate(scale_factor = 2, mode = 'bilinear'),
ConvRelu(in_channels, middle_channels), ConvRelu(middle_channels, out_channels), )
def forward(self, x):
return self.block(x)
class UNet11(nn.Module):
def __init__(self, num_filters = 32, pretrained = False):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network is used
True - encoder is pre-trained with VGG11
"""
super().__init__()
self.pool = nn.MaxPool2d(2, 2)
self.encoder = models.vgg11(pretrained = pretrained).features
self.relu = self.encoder[1]
self.conv1 = self.encoder[0]
self.conv2 = self.encoder[3]
self.conv3s = self.encoder[6]
self.conv3 = self.encoder[8]
self.conv4s = self.encoder[11]
self.conv4 = self.encoder[13]
self.conv5s = self.encoder[16]
self.conv5 = self.encoder[18]
self.center = DecoderBlock(num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8)
self.dec5 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8)
self.dec4 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4)
self.dec3 = DecoderBlock(num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2)
self.dec2 = DecoderBlock(num_filters * (4 + 2), num_filters * 2 * 2, num_filters)
self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)
self.final = nn.Conv2d(num_filters, 1, kernel_size = 1)
def forward(self, x):
conv1 = self.relu(self.conv1(x))
conv2 = self.relu(self.conv2(self.pool(conv1)))
conv3s = self.relu(self.conv3s(self.pool(conv2)))
conv3 = self.relu(self.conv3(conv3s))
conv4s = self.relu(self.conv4s(self.pool(conv3)))
conv4 = self.relu(self.conv4(conv4s))
conv5s = self.relu(self.conv5s(self.pool(conv4)))
conv5 = self.relu(self.conv5(conv5s))
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(torch.cat([dec2, conv1], 1))
return self.final(dec1)
class UNet16(nn.Module):
def __init__(self, num_classes = 1, num_filters = 32, pretrained = False, is_deconv = False):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network used
True - encoder pre-trained with VGG16
:is_deconv:
False: bilinear interpolation is used in decoder
True: deconvolution is used in decoder
"""
super().__init__()
self.n_classes = num_classes
self.pool = nn.MaxPool2d(2, 2)
self.encoder = torchvision.models.vgg16(pretrained = pretrained).features
self.relu = nn.ReLU(inplace = True)
self.conv1 = nn.Sequential(self.encoder[0], self.relu, self.encoder[2], self.relu)
self.conv2 = nn.Sequential(self.encoder[5], self.relu, self.encoder[7], self.relu)
self.conv3 = nn.Sequential(self.encoder[10], self.relu, self.encoder[12], self.relu, self.encoder[14],
self.relu)
self.conv4 = nn.Sequential(self.encoder[17], self.relu, self.encoder[19], self.relu, self.encoder[21],
self.relu)
self.conv5 = nn.Sequential(self.encoder[24], self.relu, self.encoder[26], self.relu, self.encoder[28],
self.relu)
self.center = DecoderBlock(512, num_filters * 8 * 2, num_filters * 8)
self.dec5 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
self.dec4 = DecoderBlock(512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8)
self.dec3 = DecoderBlock(256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2)
self.dec2 = DecoderBlock(128 + num_filters * 2, num_filters * 2 * 2, num_filters)
self.dec1 = ConvRelu(64 + num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size = 1)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(self.pool(conv1))
conv3 = self.conv3(self.pool(conv2))
conv4 = self.conv4(self.pool(conv3))
conv5 = self.conv5(self.pool(conv4))
center = self.center(self.pool(conv5))
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(torch.cat([dec2, conv1], 1))
if self.n_classes > 1:
x_out = F.log_softmax(self.final(dec1), dim = 1)
else:
x_out = self.final(dec1)
return x_out
......@@ -4,112 +4,21 @@ import os
import sys
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from utils.eval import eval_net
import unet
import mrnet
from mrnet import MultiUnet
from unet import UNet
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
global_step = 0
logging.info(f'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
''')
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
#optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
criterion = nn.BCEWithLogitsLoss()
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
assert imgs.shape[1] == net.n_channels, \
f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.update(imgs.shape[0])
global_step += 1
if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score = eval_net(net, val_loader, device, n_val)
if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step)
else:
logging.info('Validation Dice Coeff: {}'.format(val_score))
writer.add_scalar('Dice/test', val_score, global_step)
writer.add_images('images', imgs, global_step)
if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
writer.close()
def get_args():
parser = argparse.ArgumentParser(description = 'Train the UNet on images and target masks',
formatter_class = argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-e', '--epochs', metavar = 'E', type = int, default = 5, help = 'Number of epochs',
parser.add_argument('-e', '--epochs', metavar = 'E', type = int, default = 1, help = 'Number of epochs',
dest = 'epochs')
parser.add_argument('-b', '--batch-size', metavar = 'B', type = int, nargs = '?', default = 1, help = 'Batch size',
dest = 'batchsize')
......@@ -117,7 +26,7 @@ def get_args():
help = 'Learning rate', dest = 'lr')
parser.add_argument('-f', '--load', dest = 'load', type = str, default = False,
help = 'Load model from a .pth file')
parser.add_argument('-s', '--scale', dest = 'scale', type = float, default = 0.5,
parser.add_argument('-s', '--scale', dest = 'scale', type = float, default = 2.56,
help = 'Downscaling factor of the images')
parser.add_argument('-v', '--validation', dest = 'val', type = float, default = 10.0,
help = 'Percent of the data that is used as validation (0-100)')
......@@ -137,7 +46,10 @@ if __name__ == '__main__':
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
net = UNet(n_channels = 1, n_classes = 1)
#net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n'
......@@ -152,7 +64,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True
try:
train_net(net = net, device=device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
mrnet.train_net(net = net, device=device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
img_scale = args.scale, val_percent = args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
......
from .unet_model import UNet
from .train import train_net
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
import argparse
import logging
import os
import sys
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from utils.eval import eval_net
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
global_step = 0
logging.info(f'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
''')
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
# criterion = nn.BCEWithLogitsLoss()
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
criterion = nn.BCEWithLogitsLoss()
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
# assert imgs.shape[1] == net.n_channels, \
# f'Network has been defined with {net.n_channels} input channels, ' \
# f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
# 'the images are loaded correctly.'
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.update(imgs.shape[0])
global_step += 1 # if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score = eval_net(net, val_loader, device, n_val)
if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step)
else:
logging.info('Validation Dice Coeff: {}'.format(val_score))
writer.add_scalar('Dice/test', val_score, global_step)
writer.add_images('images', imgs, global_step)
if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp and epoch % 4 == 0:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
torch.save(net.state_dict(), 'MODEL.pth')
writer.close()
\ No newline at end of file
......@@ -13,7 +13,7 @@ class BasicDataset(Dataset):
self.imgs_dir = imgs_dir
self.masks_dir = masks_dir
self.scale = scale
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
#assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')]
......@@ -25,7 +25,7 @@ class BasicDataset(Dataset):
@classmethod
def preprocess(cls, pil_img, scale):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
newW, newH = 256,256 #int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((newW, newH))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment