Commit 74e4ef3c by 王肇一

0813, right before the meeting

parent 447cf1b5
......@@ -8,7 +8,7 @@ import torch
from PIL import Image
import cv2 as cv
from tqdm import tqdm
from multiprocessing import Pool
from joblib import delayed, Parallel
import math
import argparse
import logging
......@@ -65,6 +65,22 @@ def step_1_32bit(net,args,device,list,position):
logging.info("path already exist")
cv.imwrite('data/masks/' + fn[0] + '/' + fn[1], result)
def step_1_joblib(net, args, device, fn):
    """Predict the segmentation mask for one image and persist the results.

    Designed to be dispatched per-image via joblib ``Parallel``/``delayed``.

    Parameters:
        net: the loaded segmentation network (passed to ``predict_img``).
        args: parsed CLI arguments; only ``args.mask_threshold`` is read here.
        device: torch device the prediction runs on.
        fn: ``(sub_dir, file_name)`` pair locating the image under ``data/imgs/``.

    Side effects:
        Saves an overlay via ``save_img_mask`` and writes the 8-bit mask to
        ``data/masks/<sub_dir>/<file_name>``.
    """
    logging.info("\nPredicting image {} ...".format(fn[0] + '/' + fn[1]))
    img = Image.open('data/imgs/' + fn[0] + '/' + fn[1])
    print('data/imgs/' + fn[0] + '/' + fn[1])
    mask = predict_img(net=net, full_img=img, out_threshold=args.mask_threshold, device=device)
    # Scale the thresholded mask to an 8-bit grayscale image for saving.
    result = (mask * 255).astype(np.uint8)
    save_img_mask(img, result, fn[0], fn[1])
    # exist_ok replaces the old bare `except:` which could hide real OS errors
    # (e.g. permission denied) behind a misleading "path already exist" log.
    os.makedirs('data/masks/' + fn[0], exist_ok=True)
    cv.imwrite('data/masks/' + fn[0] + '/' + fn[1], result)
def step_2(list, position=1):
for num, dir in enumerate(list):
......@@ -162,13 +178,15 @@ def cli_main():
path = [(y, x) for y in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs')) for x in filter(
lambda x: x.endswith('.tif') and not x.endswith('dc.tif') and not x.endswith('DC.tif') and not x.endswith(
'dc .tif'), os.listdir('data/imgs/' + y))]
seperate_path = divide_list(path, args.process)
if args.step == 1:
net = UNet(n_channels = 1, n_classes = 1)
#net = MultiUnet(n_channels = 1,n_classes = 1)
#net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1,n_classes = 1)
logging.info("Loading model {}".format(args.module))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
logging.info(f'Using device {device}')
net.to(device = device)
......@@ -180,22 +198,23 @@ def cli_main():
net.load_state_dict(d)
#net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
logging.info("Model loaded !")
k = Parallel(n_jobs=8)(delayed(step_1_joblib)(net,args,device,fn) for fn in tqdm(path))
pool = Pool(args.process)
for i, list in enumerate(seperate_path):
pool.apply_async(step_1_32bit, args = (net, args, device, list, i))
pool.close()
pool.join()
# pool = Pool(args.process)
# for i, list in enumerate(seperate_path):
# pool.apply_async(step_1_32bit, args = (net, args, device, list, i))
# pool.close()
# pool.join()
elif args.step == 2:
dir = [x for x in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs/'))]
sep_dir = divide_list(dir, args.process)
pool = Pool(args.process)
for i, list in enumerate(sep_dir):
pool.apply_async(step_2_32bit, args = (list, i))
pool.close()
pool.join()
# pool = Pool(args.process)
# for i, list in enumerate(sep_dir):
# pool.apply_async(step_2_32bit, args = (list, i))
# pool.close()
# pool.join()
if __name__ == '__main__':
......
......@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset, VOCSegmentation
from utils.eval import eval_net, eval_jac
from utils.eval import eval_net#, eval_jac
from utils.dice_loss import DiceLoss
from utils.focal_loss import FocalLoss
......
......@@ -67,7 +67,7 @@ class VOCSegmentation(VisionDataset):
def __init__(self, root, image_set = 'train', transform = None, target_transform = None, transforms = None):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
base_dir = 'enhance'
base_dir = 'voc'
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
mask_dir = os.path.join(voc_root, 'SegmentationClassPNG')
......
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import jaccard_score
# from sklearn.metrics import jaccard_score
import numpy as np
from utils.dice_loss import dice_coeff
......@@ -31,20 +31,20 @@ def eval_net(net, loader, device, n_val):
return tot / n_val
def eval_jac(net, loader, device, n_val):
    """Compute the mean Jaccard (IoU) score of ``net`` over a validation loader.

    Parameters:
        net: segmentation network; called as ``net(imgs)``.
        loader: iterable yielding ``(images, ground_truth_masks)`` batches.
        device: torch device to evaluate on.
        n_val: total number of validation images, used to average the score.

    Returns:
        Sum of per-batch Jaccard scores divided by ``n_val``.
    """
    net.eval()
    jac = 0
    # no_grad: evaluation only — the original built autograd graphs for nothing.
    with torch.no_grad():
        with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
            for imgs, true_masks in loader:
                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

                # Round-then-(>0) keeps the original's effective 0.5 threshold.
                pred_masks = torch.round(net(imgs)).cpu().numpy()
                true_masks = torch.round(true_masks).cpu().numpy()
                # Vectorized binarization: the old per-element list comprehension
                # iterated only the first axis, so `x > 0` raised "truth value of
                # an array is ambiguous" for any multi-dimensional network output.
                pred_masks = (pred_masks > 0).astype(np.uint8)

                jac += jaccard_score(true_masks.flatten(), pred_masks.flatten())
                pbar.update(imgs.shape[0])
    return jac / n_val
# def eval_jac(net, loader, device, n_val):
# net.eval()
# jac = 0
# with tqdm(total = n_val, desc = 'Validation round', unit = 'img', leave = False) as pbar:
# for imgs, true_masks in loader:
# imgs = imgs.to(device = device, dtype = torch.float32)
# mask_type = torch.float32
# true_masks = true_masks.to(device = device, dtype = mask_type)
# pred_masks = net(imgs)
#
# pred_masks = torch.round(pred_masks).cpu().detach().numpy()
# true_masks = torch.round(true_masks).cpu().numpy()
# pred_masks = np.array([1 if x>0 else 0 for x in pred_masks])
# jac += jaccard_score(true_masks.flatten(), pred_masks.flatten())
#
# pbar.update(imgs.shape[0])
# return jac/n_val
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment