Commit 447cf1b5 by 王肇一

202008 update

parent 597d01e2
......@@ -14,12 +14,13 @@ import argparse
import logging
import os
import re
from collections import OrderedDict
from unet import UNet
from mrnet import MultiUnet
from utils.predict import predict_img,predict
from resCalc import save_img, get_subarea_info, save_img_mask,get_subarea_info_avgBG, get_subarea_info_fast, \
get_subarea_info_fast_outlier
from resCalc import save_img, get_subarea_info, save_img_mask, get_subarea_info_avgBG, get_subarea_info_fast_outlier, \
get_subarea_info_nobg
def divide_list(list, step):
......@@ -52,7 +53,7 @@ def step_1_32bit(net,args,device,list,position):
norm = cv.normalize(np.array(img), None, 0, 255, cv.NORM_MINMAX, cv.CV_8U)
norm = Image.fromarray(norm.astype('uint8'))
mask = predict_img(net = net, full_img = norm, out_threshold = args.mask_threshold, device = device)
# mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
#mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
result = (mask * 255).astype(np.uint8)
# save_img({'ori': img, 'mask': result}, fn[0], fn[1])
......@@ -74,8 +75,6 @@ def step_2(list, position=1):
match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name)
img = cv.imread('data/imgs/' + dir + '/' + name, 0)
mask = cv.imread('data/masks/' + dir + '/' + name, 0)
# value = get_subarea_info_fast(img, mask)
# value, count = get_subarea_info(img, mask)
value = get_subarea_info_avgBG(img, mask)
if value is not None:
ug = 0.0
......@@ -99,7 +98,7 @@ def step_2(list, position=1):
plt.savefig('data/output/'+dir+'.png')
def step_2_32bit(list,position=1):
def step_2_32bit(list, position=1):
for num, dir in enumerate(list):
# df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)'))
values = []
......@@ -107,9 +106,14 @@ def step_2_32bit(list,position=1):
for name in tqdm(names, desc = f'Period{num + 1}/{len(list)}', position = position):
match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name)
img = cv.imread('data/imgs/' + dir + '/' + name, flags = cv.IMREAD_ANYDEPTH)
# img_mean = np.mean(img)
# if np.max(img) + 125 -img_mean < 255:
# img = img+ 125 - img_mean
# else:
# img = img+ 255 - np.max(img)
mask = cv.imread('data/masks/' + dir + '/' + name, 0)
value = get_subarea_info_fast_outlier(img, mask)
#value,shape = get_subarea_info(img,mask)
value = get_subarea_info_fast_outlier(img, mask)# get_subarea_info_nobg(img,mask)
if value is not None:
ug = 0.0
if str.lower(match_group.group(1)).endswith('ug'):
......@@ -129,6 +133,7 @@ def step_2_32bit(list,position=1):
sns.set_style("darkgrid")
sns.catplot(x = 'ug', y = 'Intensity(a.u.)', kind = 'bar', palette = 'vlag', data = df)
# sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0)
plt.suptitle(dir)
plt.axhline(y = baseline_high)
plt.axhline(y = baseline_low)
plt.savefig('data/output/' + dir + '.png')
......@@ -166,7 +171,14 @@ def cli_main():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net.to(device = device)
net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
model = torch.load('data/module/' + args.module + '.pth',map_location = device)
d = OrderedDict()
for key, value in model.items():
tmp = key[7:]
d[tmp] = value
net.load_state_dict(d)
#net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
logging.info("Model loaded !")
pool = Pool(args.process)
......@@ -184,28 +196,6 @@ def cli_main():
pool.apply_async(step_2_32bit, args = (list, i))
pool.close()
pool.join()
elif args.step == 3:
net = UNet(n_channels = 1, n_classes = 1)
# net = MultiUnet(n_channels = 1,n_classes = 1)
logging.info("Loading model {}".format(args.module))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net.to(device = device)
net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
logging.info("Model loaded !")
pool = Pool(args.process)
for i, list in enumerate(seperate_path):
pool.apply_async(step_1, args = (net, args, device, list, i))
pool.close()
pool.join()
dir = [x for x in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs/'))]
sep_dir = divide_list(dir, args.process)
for i, list in enumerate(sep_dir):
pool.apply_async(step_2, args = (list, i))
pool.close()
pool.join()
if __name__ == '__main__':
......
......@@ -42,8 +42,8 @@ class MultiUnet(nn.Module):
self.pool = nn.MaxPool2d(2)
self.outconv = nn.Sequential(
nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1),
nn.Softmax()
#nn.Sigmoid()
#nn.Softmax()
nn.Sigmoid()
)
# self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)
......
......@@ -4,16 +4,20 @@
import os
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler
from torch.optim.rmsprop import RMSprop
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset, VOCSegmentation
from utils.eval import eval_net, eval_jac
from utils.dice_loss import DiceLoss
from utils.focal_loss import FocalLoss
dir_checkpoint = 'checkpoint/'
......@@ -27,9 +31,14 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
train_loader = DataLoader(trainset, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCELoss()# nn.BCEWithLogitsLoss()
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.5, patience = 5)
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}')
global_step = 0
#optimizer = optim.Adam(net.parameters(), lr = lr)
optimizer = RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8, momentum = 0.99)
#criterion = nn.BCELoss()
criterion = FocalLoss(alpha = 1, gamma = 2, logits = False)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.1, patience = 10)
for epoch in range(epochs):
net.train()
......@@ -37,25 +46,32 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for imgs, true_masks in train_loader:
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
mask_type = torch.float32# if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
optimizer.zero_grad()
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
pbar.set_postfix(**{'loss (batch)': loss.item()})
loss.backward()
optimizer.step()
pbar.update(imgs.shape[0])
global_step += 1
dice = eval_net(net, val_loader, device, n_val)
jac = eval_jac(net,val_loader,device,n_val)
# overall_acc, avg_per_class_acc, avg_jacc, avg_dice = eval_multi(net, val_loader, device, n_val)
scheduler.step(dice)
lr = optimizer.param_groups[0]['lr']
logging.info(f'Avg Dice:{dice} Jaccard:{jac}\n'
f'Learning Rate:{scheduler.get_lr()[0]}')
logging.info(f'Avg Dice:{dice}\t'
f'Learning Rate:{lr}')
writer.add_scalar('Dice/test', dice, global_step)
writer.add_images('images', imgs, global_step)
# if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', masks_pred > 0.5, global_step)
if epoch % 5 == 0:
try:
os.mkdir(dir_checkpoint)
......
......@@ -8,14 +8,15 @@ import logging
import os
import re
def save_img_mask(img,mask,dir,name):
def save_img_mask(img, mask, dir, name):
plt.figure(dpi = 300)
plt.suptitle(name)
plt.imshow(img,'gray')
#print(img.shape)
plt.imshow(img, 'gray')
# print(img.shape)
mask = cv.cvtColor(mask, cv.COLOR_GRAY2RGB)
mask[:,:,2] = 0
mask[:,:,0] = 0
mask[:, :, 2] = 0
mask[:, :, 0] = 0
plt.imshow(mask, alpha = 0.25, cmap = 'rainbow')
try:
os.makedirs('data/output/' + dir)
......@@ -50,14 +51,15 @@ def get_subarea_info(img, mask):
group = np.where(labels == i)
area_size = len(group[0])
#if area_size > 10: # 过小的区域直接剔除
# if area_size > 10: # 过小的区域直接剔除
area_value = img[group]
area_mean = np.mean(area_value)
# Background Value
pos = [(group[0][k], group[1][k]) for k in range(len(group[0]))]
area_points = np.array([mask[x, y] if (x, y) in pos else 0 for x in range(200) for y in range(200)], dtype = np.uint8).reshape([200,200])
area_points = np.array([mask[x, y] if (x, y) in pos else 0 for x in range(200) for y in range(200)],
dtype = np.uint8).reshape([200, 200])
kernel = np.ones((15, 15), np.uint8)
bg_area_mask = cv.erode(area_points, kernel)
surround_bg_mask = cv.bitwise_xor(bg_area_mask, 255 - area_points)
......@@ -65,8 +67,7 @@ def get_subarea_info(img, mask):
back_value = img[np.where(real_bg_mask != 0)]
back_mean = np.mean(back_value)
info.append({'mean': area_mean,'back':back_mean,'size':area_size})
# endif
info.append({'mean': area_mean, 'back': back_mean, 'size': area_size}) # endif
df = pd.DataFrame(info)
median = np.median(df['mean'])
......@@ -77,7 +78,7 @@ def get_subarea_info(img, mask):
df = df[df['mean'] >= lower_limit]
df = df[df['mean'] <= upper_limit]
df['value'] = df['mean']-df['back']
df['value'] = df['mean'] - df['back']
return (df['value'] * df['size']).sum() / df['size'].sum(), df.shape[0]
......@@ -97,31 +98,17 @@ def get_subarea_info_avgBG(img, mask):
area_mean = np.mean(area_value)
size += area_size
value += (area_mean-bg)*area_size
value += (area_mean - bg) * area_size
return value / size
def get_subarea_info_fast(img, mask):
def get_subarea_info_fast_outlier(img, mask):
if mask.max() == 0:
return None
else:
kernel = np.ones((15, 15), np.uint8)
bg_area_mask = cv.dilate(mask, kernel)
surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask)
sig_value = np.mean(img[np.where(mask != 0)])
back_value = np.mean(img[np.where(surround_bg_mask != 0)])
return sig_value - back_value
def get_subarea_info_fast_outlier(img,mask):
if mask.max() == 0:
return None
else:
kernel = np.ones((15, 15), np.uint8)
bg_area_mask = cv.dilate(mask, kernel)
surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask)
sig_mean = np.mean(img[np.where(mask != 0)])
back_value = np.mean(img[np.where(surround_bg_mask != 0)])
median = np.median(img[np.where(mask != 0)])
b = 1.4826
......@@ -131,14 +118,35 @@ def get_subarea_info_fast_outlier(img,mask):
bg_median = np.median(img[np.where(surround_bg_mask != 0)])
bg_mad = b * np.median(np.abs(img[np.where(surround_bg_mask != 0)] - bg_median))
bg_lower_limit = bg_median-(3*bg_mad)
bg_upper_limit = bg_median+(3*bg_mad)
bg_lower_limit = bg_median - (3 * bg_mad)
bg_upper_limit = bg_median + (3 * bg_mad)
res = img[np.where(mask != 0)]
res = res[res>=lower_limit]
res = res[res<=upper_limit]
res = res[res >= lower_limit]
res = res[res <= upper_limit]
bg = img[np.where(surround_bg_mask != 0)]
bg = bg[bg>=bg_lower_limit]
bg = bg[bg<=bg_upper_limit]
return np.mean(res) - np.mean(bg)
\ No newline at end of file
bg = bg[bg >= bg_lower_limit]
bg = bg[bg <= bg_upper_limit]
return np.mean(res) - np.mean(bg)
def get_subarea_info_nobg(img,mask):
if mask.max() == 0:
return None
else:
median = np.median(img[np.where(mask != 0)])
b = 1.4826
mad = b * np.median(np.abs(img[np.where(mask != 0)] - median))
lower_limit = median - (3 * mad)
upper_limit = median + (3 * mad)
res = img[np.where(mask != 0)]
res = res[res >= lower_limit]
res = res[res <= upper_limit]
return np.mean(res)
# todo: def get_subarea_info_fast_outlier_cluster(img,mask):
......@@ -4,6 +4,7 @@ import os
import sys
import torch
from torch import nn
import unet
import mrnet
......@@ -49,9 +50,12 @@ if __name__ == '__main__':
#net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n')
if torch.cuda.device_count() > 1:
net = nn.DataParallel(net)
# logging.info(f'Network:\n'
# f'\t{net.n_channels} input channels\n'
# f'\t{net.n_classes} output channels (classes)\n')
if args.load:
net.load_state_dict(torch.load(args.load, map_location = device))
......@@ -62,7 +66,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True
try:
unet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr)
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
......
......@@ -17,6 +17,8 @@ from utils.eval import eval_net
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset,VOCSegmentation
from utils.dice_loss import DiceLoss
from utils.focal_loss import FocalLoss
from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/'
......@@ -48,11 +50,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
optimizer = RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8, momentum=0.99)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.5, patience = 5)
if net.n_classes > 1:
criterion = nn.CrossEntropyLoss()
else:
criterion = nn.BCEWithLogitsLoss()
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.1, patience = 8)
# if net.n_classes > 1:
# criterion = nn.CrossEntropyLoss()
# else:
criterion = DiceLoss()#FocalLoss(alpha = .75, gamma = 2,logits = True)#nn.BCEWithLogitsLoss()
for epoch in range(epochs):
net.train()
......@@ -60,11 +62,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for imgs,true_masks in train_loader:
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
mask_type = torch.float32# if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
loss = criterion(torch.sigmoid(masks_pred), true_masks)
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
......@@ -78,19 +80,19 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
val_score = eval_net(net, val_loader, device, n_val)
scheduler.step(val_score)
lr = optimizer.param_groups[0]['lr']
if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step)
else:
logging.info('Validation Dice Coeff: {} lr:{}'.format(val_score,lr))
writer.add_scalar('Dice/test', val_score, global_step)
# if net.n_classes > 1:
# logging.info('Validation cross entropy: {}'.format(val_score))
# writer.add_scalar('Loss/test', val_score, global_step)
# else:
logging.info('Validation Dice Coeff: {} lr:{}'.format(val_score,lr))
writer.add_scalar('Dice/test', val_score, global_step)
writer.add_images('images', imgs, global_step)
if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
#if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp and epoch % 50 == 0:
if save_cp and (epoch+1) % 10 == 0:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
......
import torch
from torch.autograd import Function
import torch.nn as nn
class DiceCoeff(Function):
"""Dice coeff for individual examples"""
......@@ -37,11 +38,21 @@ def dice_coeff(input, target):
return s / (i + 1)
def dice_coef(pred, target):
smooth = 1.
num = pred.size(0)
m1 = pred.view(num, -1) # Flatten
m2 = target.view(num, -1) # Flatten
intersection = (m1 * m2).sum()
class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()
def forward(self, input, target):
N = target.size(0)
smooth = 1
input_flat = input.view(N, -1)
target_flat = target.view(N, -1)
intersection = input_flat * target_flat
loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
loss = 1 - loss.sum() / N
return loss
return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
\ No newline at end of file
......@@ -4,7 +4,7 @@ from tqdm import tqdm
from sklearn.metrics import jaccard_score
import numpy as np
from utils.dice_loss import dice_coeff, dice_coef
from utils.dice_loss import dice_coeff
def eval_net(net, loader, device, n_val):
......@@ -15,17 +15,17 @@ def eval_net(net, loader, device, n_val):
with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
for imgs,true_masks in loader:
imgs = imgs.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
mask_type = torch.float32# if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)
mask_pred = net(imgs)
for true_mask, pred in zip(true_masks, mask_pred):
pred = (pred > 0.5).float()
if net.n_classes > 1:
tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
else:
tot += dice_coef(pred, true_mask.squeeze(dim=1)).item()
#if net.n_classes > 1:
#tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
#else:
tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
pbar.update(imgs.shape[0])
return tot / n_val
......@@ -37,7 +37,7 @@ def eval_jac(net, loader, device, n_val):
with tqdm(total = n_val, desc = 'Validation round', unit = 'img', leave = False) as pbar:
for imgs, true_masks in loader:
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
mask_type = torch.float32
true_masks = true_masks.to(device = device, dtype = mask_type)
pred_masks = net(imgs)
......
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha = 1, gamma = 2, logits = False, reduce = True):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.logits = logits
self.reduce = reduce
def forward(self, inputs, targets):
if self.logits:
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce = False)
else:
BCE_loss = F.binary_cross_entropy(inputs, targets, reduce = False)
pt = torch.exp(-BCE_loss)
F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
if self.reduce:
return torch.mean(F_loss)
else:
return F_loss
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment