Commit 86d39706 by 王肇一

Cycle learning rate for unet

parent 515970c7
......@@ -42,7 +42,8 @@ class MultiUnet(nn.Module):
self.pool = nn.MaxPool2d(2)
self.outconv = nn.Sequential(
nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1),
nn.Sigmoid()
nn.Softmax()
#nn.Sigmoid()
)
# self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)
......
......@@ -28,8 +28,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCELoss()#nn.BCEWithLogitsLoss()
scheduler = lr_scheduler.StepLR(optimizer,30,0.5)#lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
criterion = nn.BCELoss()# nn.BCEWithLogitsLoss()
scheduler = lr_scheduler.StepLR(optimizer, 30, 0.5)# lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
for epoch in range(epochs):
net.train()
......
......@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from ignite.contrib.handlers.param_scheduler import LRScheduler
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, DiceCoefficient, ConfusionMatrix, RunningAverage
from ignite.metrics import Accuracy, Loss, DiceCoefficient, ConfusionMatrix, RunningAverage,mIoU
from ignite.contrib.handlers import ProgressBar
from argparse import ArgumentParser
......@@ -34,11 +34,13 @@ def run(train_batch_size, val_batch_size, epochs, lr):
optimizer = optim.Adam(model.parameters(), lr = lr)
cm = ConfusionMatrix(num_classes = 1)
dice = DiceCoefficient(cm)
iou = mIoU(cm)
loss = torch.nn.BCELoss() # torch.nn.NLLLoss()
scheduler = LRScheduler(lr_scheduler.ReduceLROnPlateau(optimizer))
scheduler = LRScheduler(lr_scheduler.StepLR(optimizer, 30, 0.5))
trainer = create_supervised_trainer(model, optimizer, loss, device = device)
evaluator = create_supervised_evaluator(model, metrics = {'accuracy': Accuracy(), 'dice': dice, 'nll': Loss(loss)},
evaluator = create_supervised_evaluator(model,
metrics = {'accuracy': Accuracy(), 'dice': dice, 'nll': Loss(loss)},
device = device)
RunningAverage(output_transform = lambda x: x).attach(trainer, 'loss')
trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)
......
......@@ -47,8 +47,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# criterion = nn.BCEWithLogitsLoss()
#scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
scheduler = lr_scheduler.CyclicLR(optimizer, base_lr = 1e-10, max_lr = 0.01)
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
......@@ -59,13 +59,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for imgs,true_masks in train_loader:
# imgs = batch['image']
# true_masks = batch['mask']
# assert imgs.shape[1] == net.n_channels, \
# f'Network has been defined with {net.n_channels} input channels, ' \
# f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
# 'the images are loaded correctly.'
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
......@@ -80,11 +73,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
pbar.update(imgs.shape[0])
global_step += 1 # if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score = eval_net(net, val_loader, device, n_val)
scheduler.step(val_score)
#scheduler.step(val_score)
if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step)
......
......@@ -2,6 +2,7 @@ import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import jaccard_score
import numpy as np
from utils.dice_loss import dice_coeff, dice_coef
......@@ -42,6 +43,7 @@ def eval_jac(net, loader, device, n_val):
pred_masks = torch.round(pred_masks).cpu().detach().numpy()
true_masks = torch.round(true_masks).cpu().numpy()
pred_masks = np.array([1 if x>0 else 0 for x in pred_masks])
jac += jaccard_score(true_masks.flatten(), pred_masks.flatten())
pbar.update(imgs.shape[0])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment