Commit af27b9fa by 王肇一

fine tune Unet & updating Unet to CE_Net

parent a59e70b3
......@@ -3,4 +3,6 @@ __pycache__/
data/imgs/*
data/masks/*
data/output/*
data/train_imgs/*
\ No newline at end of file
data/train_imgs/*
data/train_masks/*
.ipynb_checkpoints/
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from .ceunet_parts import DACblock, SPPblock, DecoderBlock, nonlinearity
class CEUnet(nn.Module):
def __init__(self, n_channels, n_classes):
super(CEUnet, self).__init__()
filters = [64, 128, 256, 512]
resnet = models.resnet34(pretrained = True)
weight = resnet.conv1.weight
self.inc = nn.Conv2d(in_channels = n_channels, out_channels = 64, kernel_size = 7, stride = 2, padding = 1)
self.inc.weight = nn.Parameter(weight[:, :1, :, :])
self.bn1 = resnet.bn1
self.relu1 = resnet.relu1
self.maxpool1 = resnet.maxpool
self.encoder1 = resnet.layer1
self.encoder2 = resnet.layer2
self.encoder3 = resnet.layer3
self.encoder4 = resnet.layer4
self.dblock = DACblock(512)
self.spp = SPPblock(512)
self.decoder4 = DecoderBlock(516, filters[2])
self.decoder3 = DecoderBlock(filters[2], filters[1])
self.decoder2 = DecoderBlock(filters[1], filters[0])
self.decoder1 = DecoderBlock(filters[0], filters[0])
self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 4, 2, 1)
self.finalrelu1 = nonlinearity
self.finalconv2 = nn.Conv2d(32, 32, 3, padding = 1)
self.finalrelu2 = nonlinearity
self.finalconv3 = nn.Conv2d(32, n_classes, 3, padding = 1)
def forward(self,x):
x = self.inc(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.maxpool1(x)
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
e4 = self.encoder4(e3)
db = self.dblock(e4)
spp = self.spp(db)
d4 = self.decoder4(spp)+e3
d3 = self.decoder3(d4)+e2
d2 = self.decoder2(d3)+e1
d1 = self.decoder1(d2)
out = self.finaldeconv1(d1)
out = self.finalrelu1(out)
out = self.finalconv2(out)
out = self.finalrelu2(out)
out = self.finalconv3(out)
return F.sigmoid(out)
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
nonlinearity = partial(F.relu, inplace = True)
class DACblock(nn.Module):
def __init__(self, channel):
super(DACblock, self).__init__()
self.dilate1 = nn.Conv2d(channel, channel, kernel_size = 3, dilation = 1, padding = 1)
self.dilate2 = nn.Conv2d(channel, channel, kernel_size = 3, dilation = 3, padding = 3)
self.dilate3 = nn.Conv2d(channel, channel, kernel_size = 3, dilation = 5, padding = 5)
self.conv1x1 = nn.Conv2d(channel, channel, kernel_size = 1, dilation = 1, padding = 0)
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
dilate1_out = nonlinearity(self.dilate1(x))
dilate2_out = nonlinearity(self.conv1x1(self.dilate2(x)))
dilate3_out = nonlinearity(self.conv1x1(self.dilate2(self.dilate1(x))))
dilate4_out = nonlinearity(self.conv1x1(self.dilate3(self.dilate2(self.dilate1(x)))))
out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out
return out
class SPPblock(nn.Module):
def __init__(self, in_channels):
super(SPPblock, self).__init__()
self.pool1 = nn.MaxPool2d(kernel_size = [2, 2], stride = 2)
self.pool2 = nn.MaxPool2d(kernel_size = [3, 3], stride = 3)
self.pool3 = nn.MaxPool2d(kernel_size = [5, 5], stride = 5)
self.pool4 = nn.MaxPool2d(kernel_size = [6, 6], stride = 6)
self.conv = nn.Conv2d(in_channels = in_channels, out_channels = 1, kernel_size = 1, padding = 0)
def forward(self, x):
self.in_channels, h, w = x.size(1), x.size(2), x.size(3)
self.layer1 = F.upsample(self.conv(self.pool1(x)), size = (h, w), mode = 'bilinear')
self.layer2 = F.upsample(self.conv(self.pool2(x)), size = (h, w), mode = 'bilinear')
self.layer3 = F.upsample(self.conv(self.pool3(x)), size = (h, w), mode = 'bilinear')
self.layer4 = F.upsample(self.conv(self.pool4(x)), size = (h, w), mode = 'bilinear')
out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)
return out
class DecoderBlock(nn.Module):
def __init__(self, in_channels, n_filters):
super(DecoderBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
self.norm1 = nn.BatchNorm2d(in_channels // 4)
self.relu1 = nonlinearity
self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 3, stride = 2, padding = 1,
output_padding = 1)
self.norm2 = nn.BatchNorm2d(in_channels // 4)
self.relu2 = nonlinearity
self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
self.norm3 = nn.BatchNorm2d(n_filters)
self.relu3 = nonlinearity
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.deconv2(x)
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
x = self.norm3(x)
x = self.relu3(x)
return x
......@@ -16,27 +16,19 @@ from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/'
dir_mask = 'data/masks/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
def train_net(net,
device,
epochs=5,
batch_size=1,
lr=0.1,
val_percent=0.1,
save_cp=True,
img_scale=0.5):
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
global_step = 0
logging.info(f'''Starting training:
......@@ -50,7 +42,8 @@ def train_net(net,
Images scaling: {img_scale}
''')
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
#optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
......@@ -60,18 +53,19 @@ def train_net(net,
net.train()
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
assert imgs.shape[1] == net.n_channels, \
f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
assert imgs.shape[
1] == net.n_channels, f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
f'' \
'the images are loaded correctly.'
imgs = imgs.to(device=device, dtype=torch.float32)
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
......@@ -107,34 +101,33 @@ def train_net(net,
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(),
dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
writer.close()
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
help='Number of epochs', dest='epochs')
parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
help='Batch size', dest='batchsize')
parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
help='Learning rate', dest='lr')
parser.add_argument('-f', '--load', dest='load', type=str, default=False,
help='Load model from a .pth file')
parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
help='Downscaling factor of the images')
parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
help='Percent of the data that is used as validation (0-100)')
parser = argparse.ArgumentParser(description = 'Train the UNet on images and target masks',
formatter_class = argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-e', '--epochs', metavar = 'E', type = int, default = 5, help = 'Number of epochs',
dest = 'epochs')
parser.add_argument('-b', '--batch-size', metavar = 'B', type = int, nargs = '?', default = 1, help = 'Batch size',
dest = 'batchsize')
parser.add_argument('-l', '--learning-rate', metavar = 'LR', type = float, nargs = '?', default = 0.1,
help = 'Learning rate', dest = 'lr')
parser.add_argument('-f', '--load', dest = 'load', type = str, default = False,
help = 'Load model from a .pth file')
parser.add_argument('-s', '--scale', dest = 'scale', type = float, default = 0.5,
help = 'Downscaling factor of the images')
parser.add_argument('-v', '--validation', dest = 'val', type = float, default = 10.0,
help = 'Percent of the data that is used as validation (0-100)')
return parser.parse_args()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logging.basicConfig(level = logging.INFO, format = '%(levelname)s: %(message)s')
args = get_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
......@@ -145,30 +138,23 @@ if __name__ == '__main__':
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
net = UNet(n_channels=1, n_classes=1)
net = UNet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n'
f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling')
if args.load:
net.load_state_dict(
torch.load(args.load, map_location=device)
)
net.load_state_dict(torch.load(args.load, map_location = device))
logging.info(f'Model loaded from {args.load}')
net.to(device=device)
net.to(device = device)
# faster convolutions, but more memory
# cudnn.benchmark = True
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
train_net(net = net, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, device = device,
img_scale = args.scale, val_percent = args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment