文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 597d01e2 by 王肇一

enhance 10 times

parent 86d39706
...@@ -18,7 +18,8 @@ import re ...@@ -18,7 +18,8 @@ import re
from unet import UNet from unet import UNet
from mrnet import MultiUnet from mrnet import MultiUnet
from utils.predict import predict_img,predict from utils.predict import predict_img,predict
from resCalc import save_img, get_subarea_info, save_img_mask,get_subarea_info_avgBG from resCalc import save_img, get_subarea_info, save_img_mask,get_subarea_info_avgBG, get_subarea_info_fast, \
get_subarea_info_fast_outlier
def divide_list(list, step): def divide_list(list, step):
...@@ -30,8 +31,8 @@ def step_1(net, args, device, list, position): ...@@ -30,8 +31,8 @@ def step_1(net, args, device, list, position):
for fn in tqdm(list, position = position): for fn in tqdm(list, position = position):
logging.info("\nPredicting image {} ...".format(fn[0] + '/' + fn[1])) logging.info("\nPredicting image {} ...".format(fn[0] + '/' + fn[1]))
img = Image.open('data/imgs/' + fn[0] + '/' + fn[1]) img = Image.open('data/imgs/' + fn[0] + '/' + fn[1])
#mask = predict_img(net = net, full_img = img, out_threshold = args.mask_threshold, device = device) mask = predict_img(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device) #mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
result = (mask * 255).astype(np.uint8) result = (mask * 255).astype(np.uint8)
#save_img({'ori': img, 'mask': result}, fn[0], fn[1]) #save_img({'ori': img, 'mask': result}, fn[0], fn[1])
...@@ -43,6 +44,27 @@ def step_1(net, args, device, list, position): ...@@ -43,6 +44,27 @@ def step_1(net, args, device, list, position):
cv.imwrite('data/masks/' + fn[0] + '/' + fn[1], result) cv.imwrite('data/masks/' + fn[0] + '/' + fn[1], result)
def step_1_32bit(net,args,device,list,position):
for fn in tqdm(list, position = position):
logging.info("\nPredicting image {} ...".format(fn[0] + '/' + fn[1]))
img = Image.open('data/imgs/' + fn[0] + '/' + fn[1])
#img = img.convert('L')
norm = cv.normalize(np.array(img), None, 0, 255, cv.NORM_MINMAX, cv.CV_8U)
norm = Image.fromarray(norm.astype('uint8'))
mask = predict_img(net = net, full_img = norm, out_threshold = args.mask_threshold, device = device)
# mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
result = (mask * 255).astype(np.uint8)
# save_img({'ori': img, 'mask': result}, fn[0], fn[1])
#save_img_mask(img.convert('L'), result, fn[0], fn[1])
save_img_mask(norm, result, fn[0], fn[1])
try:
os.makedirs('data/masks/' + fn[0])
except:
logging.info("path already exist")
cv.imwrite('data/masks/' + fn[0] + '/' + fn[1], result)
def step_2(list, position=1): def step_2(list, position=1):
for num, dir in enumerate(list): for num, dir in enumerate(list):
#df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)')) #df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)'))
...@@ -52,17 +74,18 @@ def step_2(list, position=1): ...@@ -52,17 +74,18 @@ def step_2(list, position=1):
match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name) match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name)
img = cv.imread('data/imgs/' + dir + '/' + name, 0) img = cv.imread('data/imgs/' + dir + '/' + name, 0)
mask = cv.imread('data/masks/' + dir + '/' + name, 0) mask = cv.imread('data/masks/' + dir + '/' + name, 0)
# value = get_subarea_info_fast(img, mask)
value, count = get_subarea_info(img, mask) # value, count = get_subarea_info(img, mask)
# value = get_subarea_info_avgBG(img,mask) value = get_subarea_info_avgBG(img, mask)
ug = 0.0 if value is not None:
if str.lower(match_group.group(1)).endswith('ug'): ug = 0.0
ug = float(str.lower(match_group.group(1))[:-2]) if str.lower(match_group.group(1)).endswith('ug'):
elif str.lower(match_group.group(1)) == 'd2o': ug = float(str.lower(match_group.group(1))[:-2])
ug = 0 elif str.lower(match_group.group(1)) == 'd2o':
elif str.lower(match_group.group(1)) == 'lb': ug = 0
ug = -1 elif str.lower(match_group.group(1)) == 'lb':
values.append({'Intensity(a.u.)': value, 'ug': ug}) ug = -1
values.append({'Intensity(a.u.)': value, 'ug': ug})
df = pd.DataFrame(values) df = pd.DataFrame(values)
df.sort_values('ug', inplace = True) df.sort_values('ug', inplace = True)
...@@ -76,6 +99,41 @@ def step_2(list, position=1): ...@@ -76,6 +99,41 @@ def step_2(list, position=1):
plt.savefig('data/output/'+dir+'.png') plt.savefig('data/output/'+dir+'.png')
def step_2_32bit(list,position=1):
for num, dir in enumerate(list):
# df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)'))
values = []
names = [x for x in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs/' + dir))]
for name in tqdm(names, desc = f'Period{num + 1}/{len(list)}', position = position):
match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name)
img = cv.imread('data/imgs/' + dir + '/' + name, flags = cv.IMREAD_ANYDEPTH)
mask = cv.imread('data/masks/' + dir + '/' + name, 0)
value = get_subarea_info_fast_outlier(img, mask)
#value,shape = get_subarea_info(img,mask)
if value is not None:
ug = 0.0
if str.lower(match_group.group(1)).endswith('ug'):
ug = float(str.lower(match_group.group(1))[:-2])
elif str.lower(match_group.group(1)) == 'd2o':
ug = 0
elif str.lower(match_group.group(1)) == 'lb':
ug = -1
values.append({'Intensity(a.u.)': value, 'ug': ug})
df = pd.DataFrame(values)
df.sort_values('ug', inplace = True)
df.replace(-1, 'lb', inplace = True)
df.replace(0, 'd2o', inplace = True)
baseline_high = df[df['ug'] == 'd2o']['Intensity(a.u.)'].mean() * 0.62
baseline_low = df[df['ug'] == 'd2o']['Intensity(a.u.)'].mean() * 0.70
sns.set_style("darkgrid")
sns.catplot(x = 'ug', y = 'Intensity(a.u.)', kind = 'bar', palette = 'vlag', data = df)
# sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0)
plt.axhline(y = baseline_high)
plt.axhline(y = baseline_low)
plt.savefig('data/output/' + dir + '.png')
def get_args(): def get_args():
parser = argparse.ArgumentParser(description = 'A simple toolkit designed by Ulden', parser = argparse.ArgumentParser(description = 'A simple toolkit designed by Ulden',
formatter_class = argparse.ArgumentDefaultsHelpFormatter) formatter_class = argparse.ArgumentDefaultsHelpFormatter)
...@@ -102,8 +160,8 @@ def cli_main(): ...@@ -102,8 +160,8 @@ def cli_main():
seperate_path = divide_list(path, args.process) seperate_path = divide_list(path, args.process)
if args.step == 1: if args.step == 1:
# net = UNet(n_channels = 1, n_classes = 1) net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1,n_classes = 1) #net = MultiUnet(n_channels = 1,n_classes = 1)
logging.info("Loading model {}".format(args.module)) logging.info("Loading model {}".format(args.module))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}') logging.info(f'Using device {device}')
...@@ -113,7 +171,7 @@ def cli_main(): ...@@ -113,7 +171,7 @@ def cli_main():
pool = Pool(args.process) pool = Pool(args.process)
for i, list in enumerate(seperate_path): for i, list in enumerate(seperate_path):
pool.apply_async(step_1, args = (net, args, device, list, i)) pool.apply_async(step_1_32bit, args = (net, args, device, list, i))
pool.close() pool.close()
pool.join() pool.join()
...@@ -123,6 +181,28 @@ def cli_main(): ...@@ -123,6 +181,28 @@ def cli_main():
pool = Pool(args.process) pool = Pool(args.process)
for i, list in enumerate(sep_dir): for i, list in enumerate(sep_dir):
pool.apply_async(step_2_32bit, args = (list, i))
pool.close()
pool.join()
elif args.step == 3:
net = UNet(n_channels = 1, n_classes = 1)
# net = MultiUnet(n_channels = 1,n_classes = 1)
logging.info("Loading model {}".format(args.module))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net.to(device = device)
net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
logging.info("Model loaded !")
pool = Pool(args.process)
for i, list in enumerate(seperate_path):
pool.apply_async(step_1, args = (net, args, device, list, i))
pool.close()
pool.join()
dir = [x for x in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs/'))]
sep_dir = divide_list(dir, args.process)
for i, list in enumerate(sep_dir):
pool.apply_async(step_2, args = (list, i)) pool.apply_async(step_2, args = (list, i))
pool.close() pool.close()
pool.join() pool.join()
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -29,7 +29,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1): ...@@ -29,7 +29,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
optimizer = optim.Adam(net.parameters(), lr = lr) optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCELoss()# nn.BCEWithLogitsLoss() criterion = nn.BCELoss()# nn.BCEWithLogitsLoss()
scheduler = lr_scheduler.StepLR(optimizer, 30, 0.5)# lr_scheduler.ReduceLROnPlateau(optimizer, 'min') scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.5, patience = 5)
for epoch in range(epochs): for epoch in range(epochs):
net.train() net.train()
...@@ -52,7 +52,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1): ...@@ -52,7 +52,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
dice = eval_net(net, val_loader, device, n_val) dice = eval_net(net, val_loader, device, n_val)
jac = eval_jac(net,val_loader,device,n_val) jac = eval_jac(net,val_loader,device,n_val)
# overall_acc, avg_per_class_acc, avg_jacc, avg_dice = eval_multi(net, val_loader, device, n_val) # overall_acc, avg_per_class_acc, avg_jacc, avg_dice = eval_multi(net, val_loader, device, n_val)
scheduler.step() scheduler.step(dice)
lr = optimizer.param_groups[0]['lr']
logging.info(f'Avg Dice:{dice} Jaccard:{jac}\n' logging.info(f'Avg Dice:{dice} Jaccard:{jac}\n'
f'Learning Rate:{scheduler.get_lr()[0]}') f'Learning Rate:{scheduler.get_lr()[0]}')
if epoch % 5 == 0: if epoch % 5 == 0:
......
...@@ -82,17 +82,63 @@ def get_subarea_info(img, mask): ...@@ -82,17 +82,63 @@ def get_subarea_info(img, mask):
return (df['value'] * df['size']).sum() / df['size'].sum(), df.shape[0] return (df['value'] * df['size']).sum() / df['size'].sum(), df.shape[0]
def get_subarea_info_avgBG(img,mask): def get_subarea_info_avgBG(img, mask):
area_num, labels, stats, centroids = cv.connectedComponentsWithStats(mask, connectivity = 8) if mask.max() == 0:
value = 0 return None
size = 0 else:
bg = np.mean(img[np.where(labels == 0)]) area_num, labels, stats, centroids = cv.connectedComponentsWithStats(mask, connectivity = 8)
for i in filter(lambda x: x != 0, range(area_num)): value = 0
group = np.where(labels == i) size = 0
area_size = len(group[0]) bg = np.mean(img[np.where(labels == 0)])
area_value = img[group] for i in filter(lambda x: x != 0, range(area_num)):
area_mean = np.mean(area_value) group = np.where(labels == i)
area_size = len(group[0])
area_value = img[group]
area_mean = np.mean(area_value)
size += area_size
value += (area_mean-bg)*area_size
return value / size
def get_subarea_info_fast(img, mask):
if mask.max() == 0:
return None
else:
kernel = np.ones((15, 15), np.uint8)
bg_area_mask = cv.dilate(mask, kernel)
surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask)
sig_value = np.mean(img[np.where(mask != 0)])
back_value = np.mean(img[np.where(surround_bg_mask != 0)])
return sig_value - back_value
size += area_size
value += (area_mean-bg)*area_size def get_subarea_info_fast_outlier(img,mask):
return value / size if mask.max() == 0:
\ No newline at end of file return None
else:
kernel = np.ones((15, 15), np.uint8)
bg_area_mask = cv.dilate(mask, kernel)
surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask)
sig_mean = np.mean(img[np.where(mask != 0)])
back_value = np.mean(img[np.where(surround_bg_mask != 0)])
median = np.median(img[np.where(mask != 0)])
b = 1.4826
mad = b * np.median(np.abs(img[np.where(mask != 0)] - median))
lower_limit = median - (3 * mad)
upper_limit = median + (3 * mad)
bg_median = np.median(img[np.where(surround_bg_mask != 0)])
bg_mad = b * np.median(np.abs(img[np.where(surround_bg_mask != 0)] - bg_median))
bg_lower_limit = bg_median-(3*bg_mad)
bg_upper_limit = bg_median+(3*bg_mad)
res = img[np.where(mask != 0)]
res = res[res>=lower_limit]
res = res[res<=upper_limit]
bg = img[np.where(surround_bg_mask != 0)]
bg = bg[bg>=bg_lower_limit]
bg = bg[bg<=bg_upper_limit]
return np.mean(res) - np.mean(bg)
\ No newline at end of file
...@@ -46,7 +46,7 @@ if __name__ == '__main__': ...@@ -46,7 +46,7 @@ if __name__ == '__main__':
# - For 2 classes, use n_classes=1 # - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N # - For N > 2 classes, use n_classes=N
# net = UNet(n_channels = 1, n_classes = 1) #net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1, n_classes = 1) net = MultiUnet(n_channels = 1, n_classes = 1)
logging.info(f'Network:\n' logging.info(f'Network:\n'
...@@ -62,7 +62,7 @@ if __name__ == '__main__': ...@@ -62,7 +62,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True # cudnn.benchmark = True
try: try:
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr) unet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth') torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt') logging.info('Saved interrupt')
......
...@@ -10,6 +10,7 @@ import torch.nn as nn ...@@ -10,6 +10,7 @@ import torch.nn as nn
from torch import optim from torch import optim
from torchvision import transforms from torchvision import transforms
from torch.optim import lr_scheduler from torch.optim import lr_scheduler
from torch.optim.rmsprop import RMSprop
from tqdm import tqdm from tqdm import tqdm
from utils.eval import eval_net from utils.eval import eval_net
...@@ -46,9 +47,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -46,9 +47,8 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
''') ''')
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8) # optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8) optimizer = RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8, momentum=0.99)
#scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min') scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps = 1e-20, factor = 0.5, patience = 5)
scheduler = lr_scheduler.CyclicLR(optimizer, base_lr = 1e-10, max_lr = 0.01)
if net.n_classes > 1: if net.n_classes > 1:
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
else: else:
...@@ -73,16 +73,16 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -73,16 +73,16 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
scheduler.step()
pbar.update(imgs.shape[0]) pbar.update(imgs.shape[0])
global_step += 1 # if global_step % (len(dataset) // (10 * batch_size)) == 0: global_step += 1 # if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score = eval_net(net, val_loader, device, n_val) val_score = eval_net(net, val_loader, device, n_val)
#scheduler.step(val_score) scheduler.step(val_score)
lr = optimizer.param_groups[0]['lr']
if net.n_classes > 1: if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score)) logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step) writer.add_scalar('Loss/test', val_score, global_step)
else: else:
logging.info('Validation Dice Coeff: {}'.format(val_score)) logging.info('Validation Dice Coeff: {} lr:{}'.format(val_score,lr))
writer.add_scalar('Dice/test', val_score, global_step) writer.add_scalar('Dice/test', val_score, global_step)
writer.add_images('images', imgs, global_step) writer.add_images('images', imgs, global_step)
...@@ -90,7 +90,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True) ...@@ -90,7 +90,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
writer.add_images('masks/true', true_masks, global_step) writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step) writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp and epoch % 5 == 0: if save_cp and epoch % 50 == 0:
try: try:
os.mkdir(dir_checkpoint) os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory') logging.info('Created checkpoint directory')
......
...@@ -67,7 +67,7 @@ class VOCSegmentation(VisionDataset): ...@@ -67,7 +67,7 @@ class VOCSegmentation(VisionDataset):
def __init__(self, root, image_set = 'train', transform = None, target_transform = None, transforms = None): def __init__(self, root, image_set = 'train', transform = None, target_transform = None, transforms = None):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform) super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
base_dir = 'voc' base_dir = 'enhance'
voc_root = os.path.join(self.root, base_dir) voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages') image_dir = os.path.join(voc_root, 'JPEGImages')
mask_dir = os.path.join(voc_root, 'SegmentationClassPNG') mask_dir = os.path.join(voc_root, 'SegmentationClassPNG')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment