文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 447cf1b5 by 王肇一

202008 update

parent 597d01e2
...@@ -14,12 +14,13 @@ import argparse ...@@ -14,12 +14,13 @@ import argparse
import logging import logging
import os import os
import re import re
from collections import OrderedDict
from unet import UNet from unet import UNet
from mrnet import MultiUnet from mrnet import MultiUnet
from utils.predict import predict_img,predict from utils.predict import predict_img,predict
from resCalc import save_img, get_subarea_info, save_img_mask,get_subarea_info_avgBG, get_subarea_info_fast, \ from resCalc import save_img, get_subarea_info, save_img_mask, get_subarea_info_avgBG, get_subarea_info_fast_outlier, \
get_subarea_info_fast_outlier get_subarea_info_nobg
def divide_list(list, step): def divide_list(list, step):
...@@ -52,7 +53,7 @@ def step_1_32bit(net,args,device,list,position): ...@@ -52,7 +53,7 @@ def step_1_32bit(net,args,device,list,position):
norm = cv.normalize(np.array(img), None, 0, 255, cv.NORM_MINMAX, cv.CV_8U) norm = cv.normalize(np.array(img), None, 0, 255, cv.NORM_MINMAX, cv.CV_8U)
norm = Image.fromarray(norm.astype('uint8')) norm = Image.fromarray(norm.astype('uint8'))
mask = predict_img(net = net, full_img = norm, out_threshold = args.mask_threshold, device = device) mask = predict_img(net = net, full_img = norm, out_threshold = args.mask_threshold, device = device)
# mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device) #mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
result = (mask * 255).astype(np.uint8) result = (mask * 255).astype(np.uint8)
# save_img({'ori': img, 'mask': result}, fn[0], fn[1]) # save_img({'ori': img, 'mask': result}, fn[0], fn[1])
...@@ -74,8 +75,6 @@ def step_2(list, position=1): ...@@ -74,8 +75,6 @@ def step_2(list, position=1):
match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name) match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name)
img = cv.imread('data/imgs/' + dir + '/' + name, 0) img = cv.imread('data/imgs/' + dir + '/' + name, 0)
mask = cv.imread('data/masks/' + dir + '/' + name, 0) mask = cv.imread('data/masks/' + dir + '/' + name, 0)
# value = get_subarea_info_fast(img, mask)
# value, count = get_subarea_info(img, mask)
value = get_subarea_info_avgBG(img, mask) value = get_subarea_info_avgBG(img, mask)
if value is not None: if value is not None:
ug = 0.0 ug = 0.0
...@@ -99,7 +98,7 @@ def step_2(list, position=1): ...@@ -99,7 +98,7 @@ def step_2(list, position=1):
plt.savefig('data/output/'+dir+'.png') plt.savefig('data/output/'+dir+'.png')
def step_2_32bit(list,position=1): def step_2_32bit(list, position=1):
for num, dir in enumerate(list): for num, dir in enumerate(list):
# df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)')) # df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)'))
values = [] values = []
...@@ -107,9 +106,14 @@ def step_2_32bit(list,position=1): ...@@ -107,9 +106,14 @@ def step_2_32bit(list,position=1):
for name in tqdm(names, desc = f'Period{num + 1}/{len(list)}', position = position): for name in tqdm(names, desc = f'Period{num + 1}/{len(list)}', position = position):
match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name) match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name)
img = cv.imread('data/imgs/' + dir + '/' + name, flags = cv.IMREAD_ANYDEPTH) img = cv.imread('data/imgs/' + dir + '/' + name, flags = cv.IMREAD_ANYDEPTH)
# img_mean = np.mean(img)
# if np.max(img) + 125 -img_mean < 255:
# img = img+ 125 - img_mean
# else:
# img = img+ 255 - np.max(img)
mask = cv.imread('data/masks/' + dir + '/' + name, 0) mask = cv.imread('data/masks/' + dir + '/' + name, 0)
value = get_subarea_info_fast_outlier(img, mask) value = get_subarea_info_fast_outlier(img, mask)# get_subarea_info_nobg(img,mask)
#value,shape = get_subarea_info(img,mask)
if value is not None: if value is not None:
ug = 0.0 ug = 0.0
if str.lower(match_group.group(1)).endswith('ug'): if str.lower(match_group.group(1)).endswith('ug'):
...@@ -129,6 +133,7 @@ def step_2_32bit(list,position=1): ...@@ -129,6 +133,7 @@ def step_2_32bit(list,position=1):
sns.set_style("darkgrid") sns.set_style("darkgrid")
sns.catplot(x = 'ug', y = 'Intensity(a.u.)', kind = 'bar', palette = 'vlag', data = df) sns.catplot(x = 'ug', y = 'Intensity(a.u.)', kind = 'bar', palette = 'vlag', data = df)
# sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0) # sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0)
plt.suptitle(dir)
plt.axhline(y = baseline_high) plt.axhline(y = baseline_high)
plt.axhline(y = baseline_low) plt.axhline(y = baseline_low)
plt.savefig('data/output/' + dir + '.png') plt.savefig('data/output/' + dir + '.png')
...@@ -166,7 +171,14 @@ def cli_main(): ...@@ -166,7 +171,14 @@ def cli_main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}') logging.info(f'Using device {device}')
net.to(device = device) net.to(device = device)
net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
model = torch.load('data/module/' + args.module + '.pth',map_location = device)
d = OrderedDict()
for key, value in model.items():
tmp = key[7:]
d[tmp] = value
net.load_state_dict(d)
#net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
logging.info("Model loaded !") logging.info("Model loaded !")
pool = Pool(args.process) pool = Pool(args.process)
...@@ -184,28 +196,6 @@ def cli_main(): ...@@ -184,28 +196,6 @@ def cli_main():
pool.apply_async(step_2_32bit, args = (list, i)) pool.apply_async(step_2_32bit, args = (list, i))
pool.close() pool.close()
pool.join() pool.join()
elif args.step == 3:
net = UNet(n_channels = 1, n_classes = 1)
# net = MultiUnet(n_channels = 1,n_classes = 1)
logging.info("Loading model {}".format(args.module))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net.to(device = device)
net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
logging.info("Model loaded !")
pool = Pool(args.process)
for i, list in enumerate(seperate_path):
pool.apply_async(step_1, args = (net, args, device, list, i))
pool.close()
pool.join()
dir = [x for x in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs/'))]
sep_dir = divide_list(dir, args.process)
for i, list in enumerate(sep_dir):
pool.apply_async(step_2, args = (list, i))
pool.close()
pool.join()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -42,8 +42,8 @@ class MultiUnet(nn.Module): ...@@ -42,8 +42,8 @@ class MultiUnet(nn.Module):
self.pool = nn.MaxPool2d(2) self.pool = nn.MaxPool2d(2)
self.outconv = nn.Sequential( self.outconv = nn.Sequential(
nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1), nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1),
nn.Softmax() #nn.Softmax()
#nn.Sigmoid() nn.Sigmoid()
) )
# self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1) # self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)
......
...@@ -4,16 +4,20 @@ ...@@ -4,16 +4,20 @@
import os import os
import logging import logging
from tqdm import tqdm from tqdm import tqdm
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
from torch.optim import lr_scheduler from torch.optim import lr_scheduler
from torch.optim.rmsprop import RMSprop
from torchvision import transforms from torchvision import transforms
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset, VOCSegmentation from utils.dataset import BasicDataset, VOCSegmentation
from utils.eval import eval_net, eval_jac from utils.eval import eval_net, eval_jac
from utils.dice_loss import DiceLoss
from utils.focal_loss import FocalLoss
dir_checkpoint = 'checkpoint/' dir_checkpoint = 'checkpoint/'
...@@ -27,9 +31,14 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1): ...@@ -27,9 +31,14 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
train_loader = DataLoader(trainset, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True) train_loader = DataLoader(trainset, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True) val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr) writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}')
criterion = nn.BCELoss()# nn.BCEWithLogitsLoss() global_step = 0
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.5, patience = 5)
#optimizer = optim.Adam(net.parameters(), lr = lr)
optimizer = RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8, momentum = 0.99)
#criterion = nn.BCELoss()
criterion = FocalLoss(alpha = 1, gamma = 2, logits = False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.1, patience = 10)
for epoch in range(epochs): for epoch in range(epochs):
net.train() net.train()
...@@ -37,25 +46,32 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1): ...@@ -37,25 +46,32 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar: with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for imgs, true_masks in train_loader: for imgs, true_masks in train_loader:
imgs = imgs.to(device = device, dtype = torch.float32) imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32# if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type) true_masks = true_masks.to(device = device, dtype = mask_type)
optimizer.zero_grad() optimizer.zero_grad()
masks_pred = net(imgs) masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks) loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item() epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
pbar.set_postfix(**{'loss (batch)': loss.item()}) pbar.set_postfix(**{'loss (batch)': loss.item()})
loss.backward() loss.backward()
optimizer.step() optimizer.step()
pbar.update(imgs.shape[0]) pbar.update(imgs.shape[0])
global_step += 1
dice = eval_net(net, val_loader, device, n_val) dice = eval_net(net, val_loader, device, n_val)
jac = eval_jac(net,val_loader,device,n_val)
# overall_acc, avg_per_class_acc, avg_jacc, avg_dice = eval_multi(net, val_loader, device, n_val)
scheduler.step(dice) scheduler.step(dice)
lr = optimizer.param_groups[0]['lr'] lr = optimizer.param_groups[0]['lr']
logging.info(f'Avg Dice:{dice} Jaccard:{jac}\n' logging.info(f'Avg Dice:{dice}\t'
f'Learning Rate:{scheduler.get_lr()[0]}') f'Learning Rate:{lr}')
writer.add_scalar('Dice/test', dice, global_step)
writer.add_images('images', imgs, global_step)
# if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', masks_pred > 0.5, global_step)
if epoch % 5 == 0: if epoch % 5 == 0:
try: try:
os.mkdir(dir_checkpoint) os.mkdir(dir_checkpoint)
......
...@@ -8,14 +8,15 @@ import logging ...@@ -8,14 +8,15 @@ import logging
import os import os
import re import re
def save_img_mask(img,mask,dir,name):
def save_img_mask(img, mask, dir, name):
plt.figure(dpi = 300) plt.figure(dpi = 300)
plt.suptitle(name) plt.suptitle(name)
plt.imshow(img,'gray') plt.imshow(img, 'gray')
#print(img.shape) # print(img.shape)
mask = cv.cvtColor(mask, cv.COLOR_GRAY2RGB) mask = cv.cvtColor(mask, cv.COLOR_GRAY2RGB)
mask[:,:,2] = 0 mask[:, :, 2] = 0
mask[:,:,0] = 0 mask[:, :, 0] = 0
plt.imshow(mask, alpha = 0.25, cmap = 'rainbow') plt.imshow(mask, alpha = 0.25, cmap = 'rainbow')
try: try:
os.makedirs('data/output/' + dir) os.makedirs('data/output/' + dir)
...@@ -50,14 +51,15 @@ def get_subarea_info(img, mask): ...@@ -50,14 +51,15 @@ def get_subarea_info(img, mask):
group = np.where(labels == i) group = np.where(labels == i)
area_size = len(group[0]) area_size = len(group[0])
#if area_size > 10: # 过小的区域直接剔除 # if area_size > 10: # 过小的区域直接剔除
area_value = img[group] area_value = img[group]
area_mean = np.mean(area_value) area_mean = np.mean(area_value)
# Background Value # Background Value
pos = [(group[0][k], group[1][k]) for k in range(len(group[0]))] pos = [(group[0][k], group[1][k]) for k in range(len(group[0]))]
area_points = np.array([mask[x, y] if (x, y) in pos else 0 for x in range(200) for y in range(200)], dtype = np.uint8).reshape([200,200]) area_points = np.array([mask[x, y] if (x, y) in pos else 0 for x in range(200) for y in range(200)],
dtype = np.uint8).reshape([200, 200])
kernel = np.ones((15, 15), np.uint8) kernel = np.ones((15, 15), np.uint8)
bg_area_mask = cv.erode(area_points, kernel) bg_area_mask = cv.erode(area_points, kernel)
surround_bg_mask = cv.bitwise_xor(bg_area_mask, 255 - area_points) surround_bg_mask = cv.bitwise_xor(bg_area_mask, 255 - area_points)
...@@ -65,8 +67,7 @@ def get_subarea_info(img, mask): ...@@ -65,8 +67,7 @@ def get_subarea_info(img, mask):
back_value = img[np.where(real_bg_mask != 0)] back_value = img[np.where(real_bg_mask != 0)]
back_mean = np.mean(back_value) back_mean = np.mean(back_value)
info.append({'mean': area_mean,'back':back_mean,'size':area_size}) info.append({'mean': area_mean, 'back': back_mean, 'size': area_size}) # endif
# endif
df = pd.DataFrame(info) df = pd.DataFrame(info)
median = np.median(df['mean']) median = np.median(df['mean'])
...@@ -77,7 +78,7 @@ def get_subarea_info(img, mask): ...@@ -77,7 +78,7 @@ def get_subarea_info(img, mask):
df = df[df['mean'] >= lower_limit] df = df[df['mean'] >= lower_limit]
df = df[df['mean'] <= upper_limit] df = df[df['mean'] <= upper_limit]
df['value'] = df['mean']-df['back'] df['value'] = df['mean'] - df['back']
return (df['value'] * df['size']).sum() / df['size'].sum(), df.shape[0] return (df['value'] * df['size']).sum() / df['size'].sum(), df.shape[0]
...@@ -97,31 +98,17 @@ def get_subarea_info_avgBG(img, mask): ...@@ -97,31 +98,17 @@ def get_subarea_info_avgBG(img, mask):
area_mean = np.mean(area_value) area_mean = np.mean(area_value)
size += area_size size += area_size
value += (area_mean-bg)*area_size value += (area_mean - bg) * area_size
return value / size return value / size
def get_subarea_info_fast(img, mask): def get_subarea_info_fast_outlier(img, mask):
if mask.max() == 0: if mask.max() == 0:
return None return None
else: else:
kernel = np.ones((15, 15), np.uint8) kernel = np.ones((15, 15), np.uint8)
bg_area_mask = cv.dilate(mask, kernel) bg_area_mask = cv.dilate(mask, kernel)
surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask) surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask)
sig_value = np.mean(img[np.where(mask != 0)])
back_value = np.mean(img[np.where(surround_bg_mask != 0)])
return sig_value - back_value
def get_subarea_info_fast_outlier(img,mask):
if mask.max() == 0:
return None
else:
kernel = np.ones((15, 15), np.uint8)
bg_area_mask = cv.dilate(mask, kernel)
surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask)
sig_mean = np.mean(img[np.where(mask != 0)])
back_value = np.mean(img[np.where(surround_bg_mask != 0)])
median = np.median(img[np.where(mask != 0)]) median = np.median(img[np.where(mask != 0)])
b = 1.4826 b = 1.4826
...@@ -131,14 +118,35 @@ def get_subarea_info_fast_outlier(img,mask): ...@@ -131,14 +118,35 @@ def get_subarea_info_fast_outlier(img,mask):
bg_median = np.median(img[np.where(surround_bg_mask != 0)]) bg_median = np.median(img[np.where(surround_bg_mask != 0)])
bg_mad = b * np.median(np.abs(img[np.where(surround_bg_mask != 0)] - bg_median)) bg_mad = b * np.median(np.abs(img[np.where(surround_bg_mask != 0)] - bg_median))
bg_lower_limit = bg_median-(3*bg_mad) bg_lower_limit = bg_median - (3 * bg_mad)
bg_upper_limit = bg_median+(3*bg_mad) bg_upper_limit = bg_median + (3 * bg_mad)
res = img[np.where(mask != 0)] res = img[np.where(mask != 0)]
res = res[res>=lower_limit] res = res[res >= lower_limit]
res = res[res<=upper_limit] res = res[res <= upper_limit]
bg = img[np.where(surround_bg_mask != 0)] bg = img[np.where(surround_bg_mask != 0)]
bg = bg[bg>=bg_lower_limit] bg = bg[bg >= bg_lower_limit]
bg = bg[bg<=bg_upper_limit] bg = bg[bg <= bg_upper_limit]
return np.mean(res) - np.mean(bg) return np.mean(res) - np.mean(bg)
def get_subarea_info_nobg(img, mask):
    """Mean foreground intensity of ``img`` under ``mask`` with outlier rejection.

    Pixels where ``mask`` is non-zero are taken as foreground. Values farther
    than 3 scaled-MAD from the foreground median are discarded before
    averaging (robust alternative to mean +/- 3*std). No background
    subtraction is performed (hence "nobg").

    :param img: 2-D intensity image (numpy array).
    :param mask: 2-D segmentation mask, same shape as ``img``; 0 = background.
    :return: float mean of inlier foreground pixels, or None if the mask is empty.
    """
    if mask.max() == 0:
        # Nothing segmented in this image.
        return None
    # Hoist the masked-pixel extraction: the values are needed several times.
    fg = img[mask != 0]
    median = np.median(fg)
    # 1.4826 scales the MAD to be a consistent estimator of sigma for
    # normally distributed data.
    b = 1.4826
    mad = b * np.median(np.abs(fg - median))
    lower_limit = median - 3 * mad
    upper_limit = median + 3 * mad
    inliers = fg[(fg >= lower_limit) & (fg <= upper_limit)]
    return np.mean(inliers)
# todo: def get_subarea_info_fast_outlier_cluster(img,mask):
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import sys import sys
import torch import torch
from torch import nn
import unet import unet
import mrnet import mrnet
...@@ -49,9 +50,12 @@ if __name__ == '__main__': ...@@ -49,9 +50,12 @@ if __name__ == '__main__':
#net = UNet(n_channels = 1, n_classes = 1) #net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1, n_classes = 1) net = MultiUnet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n' if torch.cuda.device_count() > 1:
f'\t{net.n_channels} input channels\n' net = nn.DataParallel(net)
f'\t{net.n_classes} output channels (classes)\n')
# logging.info(f'Network:\n'
# f'\t{net.n_channels} input channels\n'
# f'\t{net.n_classes} output channels (classes)\n')
if args.load: if args.load:
net.load_state_dict(torch.load(args.load, map_location = device)) net.load_state_dict(torch.load(args.load, map_location = device))
...@@ -62,7 +66,7 @@ if __name__ == '__main__': ...@@ -62,7 +66,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True # cudnn.benchmark = True
try: try:
unet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr) mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth') torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt') logging.info('Saved interrupt')
......
...@@ -17,6 +17,8 @@ from utils.eval import eval_net ...@@ -17,6 +17,8 @@ from utils.eval import eval_net
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset,VOCSegmentation from utils.dataset import BasicDataset,VOCSegmentation
from utils.dice_loss import DiceLoss
from utils.focal_loss import FocalLoss
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/' dir_img = 'data/train_imgs/'
...@@ -48,11 +50,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -48,11 +50,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8) # optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
optimizer = RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8, momentum=0.99) optimizer = RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8, momentum=0.99)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.5, patience = 5) scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.1, patience = 8)
if net.n_classes > 1: # if net.n_classes > 1:
criterion = nn.CrossEntropyLoss() # criterion = nn.CrossEntropyLoss()
else: # else:
criterion = nn.BCEWithLogitsLoss() criterion = DiceLoss()#FocalLoss(alpha = .75, gamma = 2,logits = True)#nn.BCEWithLogitsLoss()
for epoch in range(epochs): for epoch in range(epochs):
net.train() net.train()
...@@ -60,11 +62,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -60,11 +62,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar: with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for imgs,true_masks in train_loader: for imgs,true_masks in train_loader:
imgs = imgs.to(device = device, dtype = torch.float32) imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32# if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type) true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs) masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks) loss = criterion(torch.sigmoid(masks_pred), true_masks)
epoch_loss += loss.item() epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step) writer.add_scalar('Loss/train', loss.item(), global_step)
...@@ -78,19 +80,19 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -78,19 +80,19 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
val_score = eval_net(net, val_loader, device, n_val) val_score = eval_net(net, val_loader, device, n_val)
scheduler.step(val_score) scheduler.step(val_score)
lr = optimizer.param_groups[0]['lr'] lr = optimizer.param_groups[0]['lr']
if net.n_classes > 1: # if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score)) # logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step) # writer.add_scalar('Loss/test', val_score, global_step)
else: # else:
logging.info('Validation Dice Coeff: {} lr:{}'.format(val_score,lr)) logging.info('Validation Dice Coeff: {} lr:{}'.format(val_score,lr))
writer.add_scalar('Dice/test', val_score, global_step) writer.add_scalar('Dice/test', val_score, global_step)
writer.add_images('images', imgs, global_step) writer.add_images('images', imgs, global_step)
if net.n_classes == 1: #if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step) writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step) writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp and epoch % 50 == 0: if save_cp and (epoch+1) % 10 == 0:
try: try:
os.mkdir(dir_checkpoint) os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory') logging.info('Created checkpoint directory')
......
import torch import torch
from torch.autograd import Function from torch.autograd import Function
import torch.nn as nn
class DiceCoeff(Function): class DiceCoeff(Function):
"""Dice coeff for individual examples""" """Dice coeff for individual examples"""
...@@ -37,11 +38,21 @@ def dice_coeff(input, target): ...@@ -37,11 +38,21 @@ def dice_coeff(input, target):
return s / (i + 1) return s / (i + 1)
def dice_coef(pred, target): class DiceLoss(nn.Module):
smooth = 1. def __init__(self):
num = pred.size(0) super(DiceLoss, self).__init__()
m1 = pred.view(num, -1) # Flatten
m2 = target.view(num, -1) # Flatten def forward(self, input, target):
intersection = (m1 * m2).sum() N = target.size(0)
smooth = 1
input_flat = input.view(N, -1)
target_flat = target.view(N, -1)
intersection = input_flat * target_flat
loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
loss = 1 - loss.sum() / N
return loss
return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
\ No newline at end of file
...@@ -4,7 +4,7 @@ from tqdm import tqdm ...@@ -4,7 +4,7 @@ from tqdm import tqdm
from sklearn.metrics import jaccard_score from sklearn.metrics import jaccard_score
import numpy as np import numpy as np
from utils.dice_loss import dice_coeff, dice_coef from utils.dice_loss import dice_coeff
def eval_net(net, loader, device, n_val): def eval_net(net, loader, device, n_val):
...@@ -15,17 +15,17 @@ def eval_net(net, loader, device, n_val): ...@@ -15,17 +15,17 @@ def eval_net(net, loader, device, n_val):
with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar: with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
for imgs,true_masks in loader: for imgs,true_masks in loader:
imgs = imgs.to(device=device, dtype=torch.float32) imgs = imgs.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32# if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type) true_masks = true_masks.to(device=device, dtype=mask_type)
mask_pred = net(imgs) mask_pred = net(imgs)
for true_mask, pred in zip(true_masks, mask_pred): for true_mask, pred in zip(true_masks, mask_pred):
pred = (pred > 0.5).float() pred = (pred > 0.5).float()
if net.n_classes > 1: #if net.n_classes > 1:
tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item() #tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
else: #else:
tot += dice_coef(pred, true_mask.squeeze(dim=1)).item() tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
pbar.update(imgs.shape[0]) pbar.update(imgs.shape[0])
return tot / n_val return tot / n_val
...@@ -37,7 +37,7 @@ def eval_jac(net, loader, device, n_val): ...@@ -37,7 +37,7 @@ def eval_jac(net, loader, device, n_val):
with tqdm(total = n_val, desc = 'Validation round', unit = 'img', leave = False) as pbar: with tqdm(total = n_val, desc = 'Validation round', unit = 'img', leave = False) as pbar:
for imgs, true_masks in loader: for imgs, true_masks in loader:
imgs = imgs.to(device = device, dtype = torch.float32) imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32
true_masks = true_masks.to(device = device, dtype = mask_type) true_masks = true_masks.to(device = device, dtype = mask_type)
pred_masks = net(imgs) pred_masks = net(imgs)
......
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Binary focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Down-weights easy examples via the modulating factor (1 - pt)**gamma so
    training focuses on hard, misclassified pixels.

    :param alpha: scalar weight applied to the loss.
    :param gamma: focusing exponent; gamma = 0 reduces to (alpha-scaled) BCE.
    :param logits: if True, ``inputs`` are raw scores and BCE-with-logits is
        used; if False, ``inputs`` must already be probabilities in [0, 1].
    :param reduce: if True, return the mean over all elements; otherwise
        return the per-element loss tensor.
    """

    def __init__(self, alpha = 1, gamma = 2, logits = False, reduce = True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # The legacy `reduce=False` kwarg is deprecated in PyTorch;
        # `reduction='none'` is the supported way to get per-element losses.
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction = 'none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction = 'none')
        # pt is the model's probability of the true class: BCE = -log(pt).
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment