Commit 86d39706 by 王肇一

Cycle learning rate for unet

parent 515970c7
...@@ -42,7 +42,8 @@ class MultiUnet(nn.Module): ...@@ -42,7 +42,8 @@ class MultiUnet(nn.Module):
self.pool = nn.MaxPool2d(2) self.pool = nn.MaxPool2d(2)
self.outconv = nn.Sequential( self.outconv = nn.Sequential(
nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1), nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1),
nn.Sigmoid() nn.Softmax()
#nn.Sigmoid()
) )
# self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1) # self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)
......
...@@ -28,8 +28,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1): ...@@ -28,8 +28,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True) val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr) optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCELoss()#nn.BCEWithLogitsLoss() criterion = nn.BCELoss()# nn.BCEWithLogitsLoss()
scheduler = lr_scheduler.StepLR(optimizer,30,0.5)#lr_scheduler.ReduceLROnPlateau(optimizer, 'min') scheduler = lr_scheduler.StepLR(optimizer, 30, 0.5)# lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
for epoch in range(epochs): for epoch in range(epochs):
net.train() net.train()
......
...@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader ...@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
from torch.optim import lr_scheduler from torch.optim import lr_scheduler
from ignite.contrib.handlers.param_scheduler import LRScheduler from ignite.contrib.handlers.param_scheduler import LRScheduler
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, DiceCoefficient, ConfusionMatrix, RunningAverage from ignite.metrics import Accuracy, Loss, DiceCoefficient, ConfusionMatrix, RunningAverage,mIoU
from ignite.contrib.handlers import ProgressBar from ignite.contrib.handlers import ProgressBar
from argparse import ArgumentParser from argparse import ArgumentParser
...@@ -34,11 +34,13 @@ def run(train_batch_size, val_batch_size, epochs, lr): ...@@ -34,11 +34,13 @@ def run(train_batch_size, val_batch_size, epochs, lr):
optimizer = optim.Adam(model.parameters(), lr = lr) optimizer = optim.Adam(model.parameters(), lr = lr)
cm = ConfusionMatrix(num_classes = 1) cm = ConfusionMatrix(num_classes = 1)
dice = DiceCoefficient(cm) dice = DiceCoefficient(cm)
iou = mIoU(cm)
loss = torch.nn.BCELoss() # torch.nn.NLLLoss() loss = torch.nn.BCELoss() # torch.nn.NLLLoss()
scheduler = LRScheduler(lr_scheduler.ReduceLROnPlateau(optimizer)) scheduler = LRScheduler(lr_scheduler.StepLR(optimizer, 30, 0.5))
trainer = create_supervised_trainer(model, optimizer, loss, device = device) trainer = create_supervised_trainer(model, optimizer, loss, device = device)
evaluator = create_supervised_evaluator(model, metrics = {'accuracy': Accuracy(), 'dice': dice, 'nll': Loss(loss)}, evaluator = create_supervised_evaluator(model,
metrics = {'accuracy': Accuracy(), 'dice': dice, 'nll': Loss(loss)},
device = device) device = device)
RunningAverage(output_transform = lambda x: x).attach(trainer, 'loss') RunningAverage(output_transform = lambda x: x).attach(trainer, 'loss')
trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler) trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)
......
...@@ -47,8 +47,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -47,8 +47,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8) # optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8) optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min') #scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# criterion = nn.BCEWithLogitsLoss() scheduler = lr_scheduler.CyclicLR(optimizer, base_lr = 1e-10, max_lr = 0.01)
if net.n_classes > 1: if net.n_classes > 1:
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
else: else:
...@@ -59,13 +59,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -59,13 +59,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
epoch_loss = 0 epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar: with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for imgs,true_masks in train_loader: for imgs,true_masks in train_loader:
# imgs = batch['image']
# true_masks = batch['mask']
# assert imgs.shape[1] == net.n_channels, \
# f'Network has been defined with {net.n_channels} input channels, ' \
# f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
# 'the images are loaded correctly.'
imgs = imgs.to(device = device, dtype = torch.float32) imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type) true_masks = true_masks.to(device = device, dtype = mask_type)
...@@ -80,11 +73,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -80,11 +73,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
scheduler.step()
pbar.update(imgs.shape[0]) pbar.update(imgs.shape[0])
global_step += 1 # if global_step % (len(dataset) // (10 * batch_size)) == 0: global_step += 1 # if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score = eval_net(net, val_loader, device, n_val) val_score = eval_net(net, val_loader, device, n_val)
scheduler.step(val_score) #scheduler.step(val_score)
if net.n_classes > 1: if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score)) logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step) writer.add_scalar('Loss/test', val_score, global_step)
......
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from tqdm import tqdm from tqdm import tqdm
from sklearn.metrics import jaccard_score from sklearn.metrics import jaccard_score
import numpy as np
from utils.dice_loss import dice_coeff, dice_coef from utils.dice_loss import dice_coeff, dice_coef
...@@ -42,6 +43,7 @@ def eval_jac(net, loader, device, n_val): ...@@ -42,6 +43,7 @@ def eval_jac(net, loader, device, n_val):
pred_masks = torch.round(pred_masks).cpu().detach().numpy() pred_masks = torch.round(pred_masks).cpu().detach().numpy()
true_masks = torch.round(true_masks).cpu().numpy() true_masks = torch.round(true_masks).cpu().numpy()
pred_masks = np.array([1 if x>0 else 0 for x in pred_masks])
jac += jaccard_score(true_masks.flatten(), pred_masks.flatten()) jac += jaccard_score(true_masks.flatten(), pred_masks.flatten())
pbar.update(imgs.shape[0]) pbar.update(imgs.shape[0])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment