Commit 1fc01530 by 王肇一

mrnet with VOC type dataset

parent 2a781f86
__pycache__/
.idea/*
data/imgs/*
data/masks/*
data/output/*
data/train_imgs/*
data/train_masks/*
data/train_imgs_32/*
data
.ipynb_checkpoints/
runs
checkpoint
\ No newline at end of file
......@@ -16,6 +16,7 @@ import os
import re
from unet import UNet
from mrnet import MultiUnet
from utils.predict import predict_img
from resCalc import save_img, get_subarea_info, save_img_mask
......@@ -29,7 +30,7 @@ def step_1(net, args, device, list, position):
for fn in tqdm(list, position = position):
logging.info("\nPredicting image {} ...".format(fn[0] + '/' + fn[1]))
img = Image.open('data/imgs/' + fn[0] + '/' + fn[1])
mask = predict_img(net = net, full_img = img, scale_factor = args.scale, out_threshold = args.mask_threshold,
mask = predict_img(net = net, full_img = img, out_threshold = args.mask_threshold,
device = device)
result = (mask * 255).astype(np.uint8)
......@@ -100,7 +101,8 @@ def cli_main():
seperate_path = divide_list(path, args.process)
if args.step == 1:
net = UNet(n_channels = 1, n_classes = 1)
# net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1,n_classes = 1)
logging.info("Loading model {}".format(args.module))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
......
......@@ -15,15 +15,11 @@ import labelme
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('input_dir', help='input annotated directory')
parser.add_argument('output_dir', help='output dataset directory')
parser.add_argument('--labels', help='labels file', required=True)
parser.add_argument(
'--noviz', help='no visualization', action='store_true'
)
parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('input_dir', help = 'input annotated directory')
parser.add_argument('output_dir', help = 'output dataset directory')
parser.add_argument('--labels', help = 'labels file', required = True)
parser.add_argument('--noviz', help = 'no visualization', action = 'store_true')
args = parser.parse_args()
if osp.exists(args.output_dir):
......@@ -34,9 +30,7 @@ def main():
os.makedirs(osp.join(args.output_dir, 'SegmentationClass'))
os.makedirs(osp.join(args.output_dir, 'SegmentationClassPNG'))
if not args.noviz:
os.makedirs(
osp.join(args.output_dir, 'SegmentationClassVisualization')
)
os.makedirs(osp.join(args.output_dir, 'SegmentationClassVisualization'))
print('Creating dataset:', args.output_dir)
class_names = []
......@@ -58,47 +52,49 @@ def main():
f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file)
for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
print('Generating dataset from:', label_file)
with open(label_file) as f:
base = osp.splitext(osp.basename(label_file))[0]
out_img_file = osp.join(
args.output_dir, 'JPEGImages', base + '.jpg')
out_lbl_file = osp.join(
args.output_dir, 'SegmentationClass', base + '.npy')
out_png_file = osp.join(
args.output_dir, 'SegmentationClassPNG', base + '.png')
if not args.noviz:
out_viz_file = osp.join(
args.output_dir,
'SegmentationClassVisualization',
base + '.jpg',
)
data = json.load(f)
img_file = osp.join(osp.dirname(label_file), data['imagePath'])
img = np.asarray(PIL.Image.open(img_file))
PIL.Image.fromarray(img).save(out_img_file)
lbl = labelme.utils.shapes_to_label(
img_shape=img.shape,
shapes=data['shapes'],
label_name_to_value=class_name_to_id,
)
labelme.utils.lblsave(out_png_file, lbl)
np.save(out_lbl_file, lbl)
if not args.noviz:
viz = imgviz.label2rgb(
label=lbl,
img=img,
font_size=15,
label_names=class_names,
loc='rb',
)
imgviz.io.imsave(out_viz_file, viz)
for path in os.listdir(args.input_dir):
for label_file in glob.glob(osp.join(args.input_dir + '/' + path, '*.json')):
print('Generating dataset from:', label_file)
with open(label_file) as f:
base = osp.splitext(osp.basename(label_file))[0]
out_img_file = osp.join(args.output_dir, 'JPEGImages', base + '.jpg')
out_lbl_file = osp.join(args.output_dir, 'SegmentationClass', base + '.npy')
out_png_file = osp.join(args.output_dir, 'SegmentationClassPNG', base + '.png')
if not args.noviz:
out_viz_file = osp.join(args.output_dir, 'SegmentationClassVisualization', base + '.jpg', )
data = json.load(f)
img_file = osp.join(osp.dirname(label_file), data['imagePath'])
img = np.asarray(PIL.Image.open(img_file))
PIL.Image.fromarray(img).save(out_img_file)
lbl = labelme.utils.shapes_to_label(img_shape = img.shape, shapes = data['shapes'],
label_name_to_value = class_name_to_id, )
lblsave(out_png_file, lbl)
np.save(out_lbl_file, lbl)
if not args.noviz:
viz = imgviz.label2rgb(label = lbl, img = img, font_size = 15, label_names = class_names,
loc = 'rb', )
imgviz.io.imsave(out_viz_file, viz)
def lblsave(filename, lbl):
if osp.splitext(filename)[1] != '.png':
filename += '.png'
# Assume label ranses [-1, 254] for int32,
# and [0, 255] for uint8 as VOC.
if lbl.min() >= -1 and lbl.max() < 255:
lbl = np.array([1 if lbl[x, y] > 0 else 0 for x in range(200) for y in range(200)]).reshape([200, 200])
lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode = 'P')
colormap = imgviz.label_colormap()
lbl_pil.putpalette(colormap.flatten())
lbl_pil.save(filename)
else:
raise ValueError('[%s] Cannot save the pixel-wise class label as PNG. '
'Please consider using the .npy format.' % filename)
if __name__ == '__main__':
......
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 256ug 852nm 30mw 300mw tune 43.00 002
Step size0Dwell time50 K.p atcc 700603 12.16 Ampicillin 64ug 40mw 400mw tune 43.08 002
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 512ug 852nm 30mw 300mw tune 43.00 002
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 2ug 852nm 30mw 300mw tune 43.06 002
Step size0Dwell time50 S.a atcc 29213 Linezolid 0.25ug 852nm 40mw 400mw tune 43.08 003
Step size0Dwell time50 K.p atcc 700603 12.16 Tobramycin 4ug 40mw 400mw tune 43.08 001
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 1ug 852nm 30mw 300mw tune 43.06 003
Step size0Dwell time50 S.a atcc 29213 Levofloxacin 0.5ug 852nm 40mw 400mw tune 43.08 002
Step size0Dwell time50 S.a atcc 29213 Levofloxacin 0.5ug 852nm 40mw 400mw tune 43.08 003
Step size0Dwell time50 S.a atcc 29213 Linezolid 0.25ug 852nm 40mw 400mw tune 43.08 002
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 2ug 852nm 30mw 300mw tune 43.06 003
Step size0Dwell time50 Clinical Sample K. p Tobramycin 64ug 852nm 30mw 300mw tune 43.00 004
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 512ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 K.p atcc 700603 12.16 Ampicillin 64ug 40mw 400mw tune 43.08 003
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 256ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 P.a 1h lb 852nm 30mw 300mw tune 43.03 005
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 32ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 4ug 852nm 30mw 300mw tune 43.06 001
Step size0Dwell time50 K.p atcc 700603 12.16 LB 40mw 400mw tune 43.08 002
Step size0Dwell time50 E.coil atcc25922 Ceftazime 1ug 852nm 40mw 400mw tune43.08 60x oil obj 003
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 32ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 S.a atcc 29213 Ampicillin 1ug 852nm 40mw 400mw tune 43.08 004
Step size0Dwell time50 E.coil atcc25922 Imipenen 1ug 852nm 40mw 400mw tune43.08 60x oil obj 001
Step size0Dwell time50 K.p atcc 700603 12.16 Ceftazidime 64ug 40mw 400mw tune 43.08 003
Step size0Dwell time50 K.p atcc 700603 12.16 Levofloxacin 1ug 40mw 400mw tune 43.08 001
Step size0Dwell time50 K.p atcc700603 Imipenen 1ug 852nm 40mw 400mw tune43.08 001
......@@ -3,7 +3,7 @@
import torch
from torch import nn
from .mrnet_parts import MultiResBlock,ResPath,TransCompose
from .mrnet_parts import MultiResBlock, ResPath, TransCompose
class MultiUnet(nn.Module):
......@@ -40,8 +40,11 @@ class MultiUnet(nn.Module):
self.res9 = MultiResBlock(self.up9.outc*2, 32)
self.pool = nn.MaxPool2d(2)
self.outconv = nn.Conv2d(self.res9.outc, n_classes, kernel_size = 3, padding = 1)
# self.outconv = nn.Sequential(
# nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1),
# nn.Sigmoid()
# )
self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)
def forward(self, x):
x = self.inconv(x)
......
......@@ -4,6 +4,21 @@ import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchsnooper
def conv(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = 3, stride = 1,
padding = 1),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace = True))
def shortcut(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1),
nn.BatchNorm2d(out_channels))
class MultiResBlock(nn.Module):
......@@ -13,26 +28,26 @@ class MultiResBlock(nn.Module):
self.W = U * alpha
self.inc = in_channels
self.outc = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5)
self.shortcut = shortcut(self.inc,self.outc)
def conv(self, in_channel, out_channel, kernel_size):
return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1,padding = 1),
nn.BatchNorm2d(out_channel), nn.ReLU(inplace = True))
self.conv3 = conv(self.inc, int(self.W * 0.167))
self.conv5 = conv(int(self.W * 0.167), int(self.W * 0.333))
self.conv7 = conv(int(self.W * 0.333), int(self.W * 0.5))
self.norm = nn.BatchNorm2d(self.outc)
self.seq = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(self.outc))
#@torchsnooper.snoop()
def forward(self, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.outc))(x)
shortcut = self.shortcut(x)
conv3 = self.conv(self.inc, int(self.W * 0.167), 3)(x)
conv5 = self.conv(int(self.W * 0.167), int(self.W * 0.333), 3)(conv3)
conv7 = self.conv(int(self.W * 0.333), int(self.W * 0.5), 3)(conv5)
result = torch.cat([conv3, conv5, conv7], 1)
result = nn.BatchNorm2d(self.outc)(result)
conv3 = self.conv3(x)
conv5 = self.conv5(conv3)
conv7 = self.conv7(conv5)
result = torch.add(result, shortcut)
result = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(self.outc))(result)
return result
comb = torch.cat([conv3, conv5, conv7], 1)
result = self.norm(comb)
return self.seq(torch.add(result, shortcut))
class TransCompose(nn.Module):
......@@ -40,12 +55,12 @@ class TransCompose(nn.Module):
super().__init__()
self.inc = in_channels
self.outc = out_channels
self.proc = nn.Sequential(
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 2, stride = 2),
nn.BatchNorm2d(self.outc))
def forward(self, x):
return nn.Sequential(
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 2, stride = 2),
nn.BatchNorm2d(self.outc)
)(x)
return self.proc(x)
class ResPath(nn.Module):
......@@ -55,20 +70,16 @@ class ResPath(nn.Module):
self.inc = in_channels
self.outc = out_channels
def unit(self, in_channels, out_channels, x):
shortcut = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.outc))(x)
conv = nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True))(x)
self.shortcut1 = shortcut(self.inc,self.outc)
self.conv1 = conv(self.inc,self.outc)
self.shortcut = shortcut(self.outc, self.outc)
self.conv = conv(self.outc, self.outc)
result = torch.add(conv, shortcut)
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
self.seq = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))
def forward(self, x):
x = self.unit(self.inc, self.outc, x)
x = self.seq(torch.add(self.conv1(x), self.shortcut1(x)))
for i in range(self.length - 1):
x = self.unit(self.outc, self.outc, x)
x = self.seq(torch.add(self.conv(x), self.shortcut(x)))
return x
......@@ -7,9 +7,10 @@ from tqdm import tqdm
import torch
import torch.nn as nn
from torch import optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from utils.dataset import BasicDataset
from utils.dataset import BasicDataset,VOCSegmentation
from utils.eval import eval_net
......@@ -18,47 +19,61 @@ dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1):
# dataset = BasicDataset(dir_img, dir_mask)
# n_val = int(len(dataset) * val_percent)
# n_train = len(dataset) - n_val
# train, val = random_split(dataset, [n_train, n_val])
# train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
# val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
trans = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor()
])
trainset = VOCSegmentation('data', 'train', trans, trans)
evalset = VOCSegmentation('data', 'traineval', trans, trans)
n_train = len(trainset)
n_val = len(evalset)
train_loader = DataLoader(trainset, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCEWithLogitsLoss()
for epoch in tqdm(range(epochs)):
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
for imgs, true_masks in train_loader:
# imgs = batch['image']
# true_masks = batch['mask']
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
optimizer.zero_grad()
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item()
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.update(imgs.shape[0])
val_score = eval_net(net, val_loader, device, n_val)
logging.info('Validation cross entropy: {}'.format(val_score))
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info('Validation : {}'.format(val_score))
if epoch % 5 == 0:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
torch.save(net.state_dict(), 'MODEL.pth')
......@@ -63,8 +63,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True
try:
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr,
img_scale = args.scale, val_percent = args.val / 100)
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, val_percent = args.val / 100)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
......
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from os.path import splitext
from os import listdir
import numpy as np
......@@ -7,13 +9,15 @@ from torch.utils.data import Dataset
import logging
from PIL import Image
import os
from torchvision.datasets.vision import VisionDataset
class BasicDataset(Dataset):
def __init__(self, imgs_dir, masks_dir, scale=1):
self.imgs_dir = imgs_dir
self.masks_dir = masks_dir
self.scale = scale
#assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')]
......@@ -24,8 +28,6 @@ class BasicDataset(Dataset):
@classmethod
def preprocess(cls, pil_img):
#newW, newH = 256,256 #int(scale * w), int(scale * h)
#assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((256, 256))
img_nd = np.array(pil_img)
......@@ -59,3 +61,68 @@ class BasicDataset(Dataset):
mask = self.preprocess(mask)
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
class VOCSegmentation(VisionDataset):
def __init__(self, root, image_set = 'train', transform = None, target_transform = None, transforms = None):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
base_dir = 'voc'
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
mask_dir = os.path.join(voc_root, 'SegmentationClassPNG')
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it')
split_f = os.path.join(voc_root, image_set.rstrip('\n') + '.txt')
with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks))
@classmethod
def preprocess(cls, pil_img):
pil_img = pil_img.resize((256, 256))
img_nd = np.array(pil_img)
if len(img_nd.shape) == 2:
img_nd = np.expand_dims(img_nd, axis = 2)
# HWC to CHW
img_trans = img_nd.transpose((2, 0, 1))
if img_trans.max() > 1:
img_trans = img_trans / 255
return img_trans
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images[index]).convert('L')
target = Image.open(self.masks[index]).convert('L')
pim = target.load()
for i in range(200):
for j in range(200):
pim[i, j] = 1 if pim[i,j] >0 else 0
if self.transforms is not None:
img, target = self.transforms(img, target)
# img = self.preprocess(img)
# target = self.preprocess(target)
return img, target
#return {'image':torch.from_numpy(np.asarray(img)), 'mask':torch.from_numpy(np.asarray(target))}
def __len__(self):
return len(self.images)
......@@ -4,9 +4,10 @@ from torchvision import transforms
from utils.dataset import BasicDataset
def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5):
def predict_img(net, full_img, device, out_threshold = 0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
img = torch.from_numpy(BasicDataset.preprocess(full_img))
img = img.unsqueeze(0)
img = img.to(device = device, dtype = torch.float32)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment