Commit 1fc01530 by 王肇一

mrnet with VOC type dataset

parent 2a781f86
__pycache__/ __pycache__/
.idea/* .idea/*
data/imgs/* data
data/masks/*
data/output/*
data/train_imgs/*
data/train_masks/*
data/train_imgs_32/*
.ipynb_checkpoints/ .ipynb_checkpoints/
runs runs
checkpoint
\ No newline at end of file
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
import re import re
from unet import UNet from unet import UNet
from mrnet import MultiUnet
from utils.predict import predict_img from utils.predict import predict_img
from resCalc import save_img, get_subarea_info, save_img_mask from resCalc import save_img, get_subarea_info, save_img_mask
...@@ -29,7 +30,7 @@ def step_1(net, args, device, list, position): ...@@ -29,7 +30,7 @@ def step_1(net, args, device, list, position):
for fn in tqdm(list, position = position): for fn in tqdm(list, position = position):
logging.info("\nPredicting image {} ...".format(fn[0] + '/' + fn[1])) logging.info("\nPredicting image {} ...".format(fn[0] + '/' + fn[1]))
img = Image.open('data/imgs/' + fn[0] + '/' + fn[1]) img = Image.open('data/imgs/' + fn[0] + '/' + fn[1])
mask = predict_img(net = net, full_img = img, scale_factor = args.scale, out_threshold = args.mask_threshold, mask = predict_img(net = net, full_img = img, out_threshold = args.mask_threshold,
device = device) device = device)
result = (mask * 255).astype(np.uint8) result = (mask * 255).astype(np.uint8)
...@@ -100,7 +101,8 @@ def cli_main(): ...@@ -100,7 +101,8 @@ def cli_main():
seperate_path = divide_list(path, args.process) seperate_path = divide_list(path, args.process)
if args.step == 1: if args.step == 1:
net = UNet(n_channels = 1, n_classes = 1) # net = UNet(n_channels = 1, n_classes = 1)
net = MultiUnet(n_channels = 1,n_classes = 1)
logging.info("Loading model {}".format(args.module)) logging.info("Loading model {}".format(args.module))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}') logging.info(f'Using device {device}')
......
...@@ -15,15 +15,11 @@ import labelme ...@@ -15,15 +15,11 @@ import labelme
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(formatter_class = argparse.ArgumentDefaultsHelpFormatter)
formatter_class=argparse.ArgumentDefaultsHelpFormatter parser.add_argument('input_dir', help = 'input annotated directory')
) parser.add_argument('output_dir', help = 'output dataset directory')
parser.add_argument('input_dir', help='input annotated directory') parser.add_argument('--labels', help = 'labels file', required = True)
parser.add_argument('output_dir', help='output dataset directory') parser.add_argument('--noviz', help = 'no visualization', action = 'store_true')
parser.add_argument('--labels', help='labels file', required=True)
parser.add_argument(
'--noviz', help='no visualization', action='store_true'
)
args = parser.parse_args() args = parser.parse_args()
if osp.exists(args.output_dir): if osp.exists(args.output_dir):
...@@ -34,9 +30,7 @@ def main(): ...@@ -34,9 +30,7 @@ def main():
os.makedirs(osp.join(args.output_dir, 'SegmentationClass')) os.makedirs(osp.join(args.output_dir, 'SegmentationClass'))
os.makedirs(osp.join(args.output_dir, 'SegmentationClassPNG')) os.makedirs(osp.join(args.output_dir, 'SegmentationClassPNG'))
if not args.noviz: if not args.noviz:
os.makedirs( os.makedirs(osp.join(args.output_dir, 'SegmentationClassVisualization'))
osp.join(args.output_dir, 'SegmentationClassVisualization')
)
print('Creating dataset:', args.output_dir) print('Creating dataset:', args.output_dir)
class_names = [] class_names = []
...@@ -58,47 +52,49 @@ def main(): ...@@ -58,47 +52,49 @@ def main():
f.writelines('\n'.join(class_names)) f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file) print('Saved class_names:', out_class_names_file)
for label_file in glob.glob(osp.join(args.input_dir, '*.json')): for path in os.listdir(args.input_dir):
print('Generating dataset from:', label_file) for label_file in glob.glob(osp.join(args.input_dir + '/' + path, '*.json')):
with open(label_file) as f: print('Generating dataset from:', label_file)
base = osp.splitext(osp.basename(label_file))[0] with open(label_file) as f:
out_img_file = osp.join( base = osp.splitext(osp.basename(label_file))[0]
args.output_dir, 'JPEGImages', base + '.jpg') out_img_file = osp.join(args.output_dir, 'JPEGImages', base + '.jpg')
out_lbl_file = osp.join( out_lbl_file = osp.join(args.output_dir, 'SegmentationClass', base + '.npy')
args.output_dir, 'SegmentationClass', base + '.npy') out_png_file = osp.join(args.output_dir, 'SegmentationClassPNG', base + '.png')
out_png_file = osp.join( if not args.noviz:
args.output_dir, 'SegmentationClassPNG', base + '.png') out_viz_file = osp.join(args.output_dir, 'SegmentationClassVisualization', base + '.jpg', )
if not args.noviz:
out_viz_file = osp.join( data = json.load(f)
args.output_dir,
'SegmentationClassVisualization', img_file = osp.join(osp.dirname(label_file), data['imagePath'])
base + '.jpg', img = np.asarray(PIL.Image.open(img_file))
) PIL.Image.fromarray(img).save(out_img_file)
data = json.load(f) lbl = labelme.utils.shapes_to_label(img_shape = img.shape, shapes = data['shapes'],
label_name_to_value = class_name_to_id, )
img_file = osp.join(osp.dirname(label_file), data['imagePath']) lblsave(out_png_file, lbl)
img = np.asarray(PIL.Image.open(img_file))
PIL.Image.fromarray(img).save(out_img_file) np.save(out_lbl_file, lbl)
lbl = labelme.utils.shapes_to_label( if not args.noviz:
img_shape=img.shape, viz = imgviz.label2rgb(label = lbl, img = img, font_size = 15, label_names = class_names,
shapes=data['shapes'], loc = 'rb', )
label_name_to_value=class_name_to_id, imgviz.io.imsave(out_viz_file, viz)
)
labelme.utils.lblsave(out_png_file, lbl)
def lblsave(filename, lbl):
np.save(out_lbl_file, lbl) if osp.splitext(filename)[1] != '.png':
filename += '.png'
if not args.noviz: # Assume label ranses [-1, 254] for int32,
viz = imgviz.label2rgb( # and [0, 255] for uint8 as VOC.
label=lbl, if lbl.min() >= -1 and lbl.max() < 255:
img=img, lbl = np.array([1 if lbl[x, y] > 0 else 0 for x in range(200) for y in range(200)]).reshape([200, 200])
font_size=15, lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode = 'P')
label_names=class_names, colormap = imgviz.label_colormap()
loc='rb', lbl_pil.putpalette(colormap.flatten())
) lbl_pil.save(filename)
imgviz.io.imsave(out_viz_file, viz) else:
raise ValueError('[%s] Cannot save the pixel-wise class label as PNG. '
'Please consider using the .npy format.' % filename)
if __name__ == '__main__': if __name__ == '__main__':
......
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 256ug 852nm 30mw 300mw tune 43.00 002
Step size0Dwell time50 K.p atcc 700603 12.16 Ampicillin 64ug 40mw 400mw tune 43.08 002
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 512ug 852nm 30mw 300mw tune 43.00 002
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 2ug 852nm 30mw 300mw tune 43.06 002
Step size0Dwell time50 S.a atcc 29213 Linezolid 0.25ug 852nm 40mw 400mw tune 43.08 003
Step size0Dwell time50 K.p atcc 700603 12.16 Tobramycin 4ug 40mw 400mw tune 43.08 001
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 1ug 852nm 30mw 300mw tune 43.06 003
Step size0Dwell time50 S.a atcc 29213 Levofloxacin 0.5ug 852nm 40mw 400mw tune 43.08 002
Step size0Dwell time50 S.a atcc 29213 Levofloxacin 0.5ug 852nm 40mw 400mw tune 43.08 003
Step size0Dwell time50 S.a atcc 29213 Linezolid 0.25ug 852nm 40mw 400mw tune 43.08 002
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 2ug 852nm 30mw 300mw tune 43.06 003
Step size0Dwell time50 Clinical Sample K. p Tobramycin 64ug 852nm 30mw 300mw tune 43.00 004
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 512ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 K.p atcc 700603 12.16 Ampicillin 64ug 40mw 400mw tune 43.08 003
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 256ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 P.a 1h lb 852nm 30mw 300mw tune 43.03 005
Step size0Dwell time50 Clinical Sample K. p Ceftazidime 32ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 P.a atcc27853 Levofloxacin-11.29 4ug 852nm 30mw 300mw tune 43.06 001
Step size0Dwell time50 K.p atcc 700603 12.16 LB 40mw 400mw tune 43.08 002
Step size0Dwell time50 E.coil atcc25922 Ceftazime 1ug 852nm 40mw 400mw tune43.08 60x oil obj 003
Step size0Dwell time50 Clinical Sample K. p Ampicillcin 32ug 852nm 30mw 300mw tune 43.00 003
Step size0Dwell time50 S.a atcc 29213 Ampicillin 1ug 852nm 40mw 400mw tune 43.08 004
Step size0Dwell time50 E.coil atcc25922 Imipenen 1ug 852nm 40mw 400mw tune43.08 60x oil obj 001
Step size0Dwell time50 K.p atcc 700603 12.16 Ceftazidime 64ug 40mw 400mw tune 43.08 003
Step size0Dwell time50 K.p atcc 700603 12.16 Levofloxacin 1ug 40mw 400mw tune 43.08 001
Step size0Dwell time50 K.p atcc700603 Imipenen 1ug 852nm 40mw 400mw tune43.08 001
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import torch import torch
from torch import nn from torch import nn
from .mrnet_parts import MultiResBlock,ResPath,TransCompose from .mrnet_parts import MultiResBlock, ResPath, TransCompose
class MultiUnet(nn.Module): class MultiUnet(nn.Module):
...@@ -40,8 +40,11 @@ class MultiUnet(nn.Module): ...@@ -40,8 +40,11 @@ class MultiUnet(nn.Module):
self.res9 = MultiResBlock(self.up9.outc*2, 32) self.res9 = MultiResBlock(self.up9.outc*2, 32)
self.pool = nn.MaxPool2d(2) self.pool = nn.MaxPool2d(2)
# self.outconv = nn.Sequential(
self.outconv = nn.Conv2d(self.res9.outc, n_classes, kernel_size = 3, padding = 1) # nn.Conv2d(self.res9.outc, n_classes, kernel_size = 1),
# nn.Sigmoid()
# )
self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)
def forward(self, x): def forward(self, x):
x = self.inconv(x) x = self.inconv(x)
......
...@@ -4,6 +4,21 @@ import torch ...@@ -4,6 +4,21 @@ import torch
import torchvision import torchvision
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchsnooper
def conv(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = 3, stride = 1,
padding = 1),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace = True))
def shortcut(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1),
nn.BatchNorm2d(out_channels))
class MultiResBlock(nn.Module): class MultiResBlock(nn.Module):
...@@ -13,26 +28,26 @@ class MultiResBlock(nn.Module): ...@@ -13,26 +28,26 @@ class MultiResBlock(nn.Module):
self.W = U * alpha self.W = U * alpha
self.inc = in_channels self.inc = in_channels
self.outc = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5) self.outc = int(self.W * 0.167) + int(self.W * 0.333) + int(self.W * 0.5)
self.shortcut = shortcut(self.inc,self.outc)
def conv(self, in_channel, out_channel, kernel_size): self.conv3 = conv(self.inc, int(self.W * 0.167))
return nn.Sequential( self.conv5 = conv(int(self.W * 0.167), int(self.W * 0.333))
nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = kernel_size, stride = 1,padding = 1), self.conv7 = conv(int(self.W * 0.333), int(self.W * 0.5))
nn.BatchNorm2d(out_channel), nn.ReLU(inplace = True))
self.norm = nn.BatchNorm2d(self.outc)
self.seq = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(self.outc))
#@torchsnooper.snoop()
def forward(self, x): def forward(self, x):
shortcut = nn.Sequential( shortcut = self.shortcut(x)
nn.Conv2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.outc))(x)
conv3 = self.conv(self.inc, int(self.W * 0.167), 3)(x) conv3 = self.conv3(x)
conv5 = self.conv(int(self.W * 0.167), int(self.W * 0.333), 3)(conv3) conv5 = self.conv5(conv3)
conv7 = self.conv(int(self.W * 0.333), int(self.W * 0.5), 3)(conv5) conv7 = self.conv7(conv5)
result = torch.cat([conv3, conv5, conv7], 1)
result = nn.BatchNorm2d(self.outc)(result)
result = torch.add(result, shortcut) comb = torch.cat([conv3, conv5, conv7], 1)
result = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(self.outc))(result) result = self.norm(comb)
return result return self.seq(torch.add(result, shortcut))
class TransCompose(nn.Module): class TransCompose(nn.Module):
...@@ -40,12 +55,12 @@ class TransCompose(nn.Module): ...@@ -40,12 +55,12 @@ class TransCompose(nn.Module):
super().__init__() super().__init__()
self.inc = in_channels self.inc = in_channels
self.outc = out_channels self.outc = out_channels
self.proc = nn.Sequential(
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 2, stride = 2),
nn.BatchNorm2d(self.outc))
def forward(self, x): def forward(self, x):
return nn.Sequential( return self.proc(x)
nn.ConvTranspose2d(in_channels = self.inc, out_channels = self.outc, kernel_size = 2, stride = 2),
nn.BatchNorm2d(self.outc)
)(x)
class ResPath(nn.Module): class ResPath(nn.Module):
...@@ -55,20 +70,16 @@ class ResPath(nn.Module): ...@@ -55,20 +70,16 @@ class ResPath(nn.Module):
self.inc = in_channels self.inc = in_channels
self.outc = out_channels self.outc = out_channels
def unit(self, in_channels, out_channels, x): self.shortcut1 = shortcut(self.inc,self.outc)
shortcut = nn.Sequential( self.conv1 = conv(self.inc,self.outc)
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1, stride = 1),
nn.BatchNorm2d(self.outc))(x) self.shortcut = shortcut(self.outc, self.outc)
conv = nn.Sequential( self.conv = conv(self.outc, self.outc)
nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace = True))(x)
result = torch.add(conv, shortcut) self.seq = nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))
return nn.Sequential(nn.ReLU(inplace = True), nn.BatchNorm2d(out_channels))(result)
def forward(self, x): def forward(self, x):
x = self.unit(self.inc, self.outc, x) x = self.seq(torch.add(self.conv1(x), self.shortcut1(x)))
for i in range(self.length - 1): for i in range(self.length - 1):
x = self.unit(self.outc, self.outc, x) x = self.seq(torch.add(self.conv(x), self.shortcut(x)))
return x return x
...@@ -7,9 +7,10 @@ from tqdm import tqdm ...@@ -7,9 +7,10 @@ from tqdm import tqdm
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
from utils.dataset import BasicDataset from utils.dataset import BasicDataset,VOCSegmentation
from utils.eval import eval_net from utils.eval import eval_net
...@@ -18,47 +19,61 @@ dir_mask = 'data/train_masks/' ...@@ -18,47 +19,61 @@ dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint/' dir_checkpoint = 'checkpoint/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, img_scale = 0.5): def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1):
dataset = BasicDataset(dir_img, dir_mask, img_scale) # dataset = BasicDataset(dir_img, dir_mask)
n_val = int(len(dataset) * val_percent) # n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val # n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val]) # train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True) # train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True) # val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
trans = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor()
])
trainset = VOCSegmentation('data', 'train', trans, trans)
evalset = VOCSegmentation('data', 'traineval', trans, trans)
n_train = len(trainset)
n_val = len(evalset)
train_loader = DataLoader(trainset, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
optimizer = optim.Adam(net.parameters(), lr = lr) optimizer = optim.Adam(net.parameters(), lr = lr)
criterion = nn.BCEWithLogitsLoss() criterion = nn.BCEWithLogitsLoss()
for epoch in tqdm(range(epochs)): for epoch in range(epochs):
net.train() net.train()
epoch_loss = 0 epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar: with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader: for imgs, true_masks in train_loader:
imgs = batch['image'] # imgs = batch['image']
true_masks = batch['mask'] # true_masks = batch['mask']
imgs = imgs.to(device = device, dtype = torch.float32) imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type) true_masks = true_masks.to(device = device, dtype = mask_type)
optimizer.zero_grad()
masks_pred = net(imgs) masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks) loss = criterion(masks_pred, true_masks)
epoch_loss += loss.item() epoch_loss += loss.item()
pbar.set_postfix(**{'loss (batch)': loss.item()}) pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
pbar.update(imgs.shape[0]) pbar.update(imgs.shape[0])
val_score = eval_net(net, val_loader, device, n_val) val_score = eval_net(net, val_loader, device, n_val)
logging.info('Validation cross entropy: {}'.format(val_score)) logging.info('Validation : {}'.format(val_score))
try: if epoch % 5 == 0:
os.mkdir(dir_checkpoint) try:
logging.info('Created checkpoint directory') os.mkdir(dir_checkpoint)
except OSError: logging.info('Created checkpoint directory')
pass except OSError:
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth') pass
torch.save(net.state_dict(), dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
torch.save(net.state_dict(), 'MODEL.pth') torch.save(net.state_dict(), 'MODEL.pth')
...@@ -63,8 +63,7 @@ if __name__ == '__main__': ...@@ -63,8 +63,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True # cudnn.benchmark = True
try: try:
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, val_percent = args.val / 100)
img_scale = args.scale, val_percent = args.val / 100)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth') torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt') logging.info('Saved interrupt')
......
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from os.path import splitext from os.path import splitext
from os import listdir from os import listdir
import numpy as np import numpy as np
...@@ -7,13 +9,15 @@ from torch.utils.data import Dataset ...@@ -7,13 +9,15 @@ from torch.utils.data import Dataset
import logging import logging
from PIL import Image from PIL import Image
import os
from torchvision.datasets.vision import VisionDataset
class BasicDataset(Dataset): class BasicDataset(Dataset):
def __init__(self, imgs_dir, masks_dir, scale=1): def __init__(self, imgs_dir, masks_dir, scale=1):
self.imgs_dir = imgs_dir self.imgs_dir = imgs_dir
self.masks_dir = masks_dir self.masks_dir = masks_dir
self.scale = scale self.scale = scale
#assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.ids = [splitext(file)[0] for file in listdir(imgs_dir) self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')] if not file.startswith('.')]
...@@ -24,8 +28,6 @@ class BasicDataset(Dataset): ...@@ -24,8 +28,6 @@ class BasicDataset(Dataset):
@classmethod @classmethod
def preprocess(cls, pil_img): def preprocess(cls, pil_img):
#newW, newH = 256,256 #int(scale * w), int(scale * h)
#assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((256, 256)) pil_img = pil_img.resize((256, 256))
img_nd = np.array(pil_img) img_nd = np.array(pil_img)
...@@ -59,3 +61,68 @@ class BasicDataset(Dataset): ...@@ -59,3 +61,68 @@ class BasicDataset(Dataset):
mask = self.preprocess(mask) mask = self.preprocess(mask)
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)} return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
class VOCSegmentation(VisionDataset):
def __init__(self, root, image_set = 'train', transform = None, target_transform = None, transforms = None):
super(VOCSegmentation, self).__init__(root, transforms, transform, target_transform)
base_dir = 'voc'
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
mask_dir = os.path.join(voc_root, 'SegmentationClassPNG')
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it')
split_f = os.path.join(voc_root, image_set.rstrip('\n') + '.txt')
with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks))
@classmethod
def preprocess(cls, pil_img):
pil_img = pil_img.resize((256, 256))
img_nd = np.array(pil_img)
if len(img_nd.shape) == 2:
img_nd = np.expand_dims(img_nd, axis = 2)
# HWC to CHW
img_trans = img_nd.transpose((2, 0, 1))
if img_trans.max() > 1:
img_trans = img_trans / 255
return img_trans
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images[index]).convert('L')
target = Image.open(self.masks[index]).convert('L')
pim = target.load()
for i in range(200):
for j in range(200):
pim[i, j] = 1 if pim[i,j] >0 else 0
if self.transforms is not None:
img, target = self.transforms(img, target)
# img = self.preprocess(img)
# target = self.preprocess(target)
return img, target
#return {'image':torch.from_numpy(np.asarray(img)), 'mask':torch.from_numpy(np.asarray(target))}
def __len__(self):
return len(self.images)
...@@ -4,9 +4,10 @@ from torchvision import transforms ...@@ -4,9 +4,10 @@ from torchvision import transforms
from utils.dataset import BasicDataset from utils.dataset import BasicDataset
def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5):
def predict_img(net, full_img, device, out_threshold = 0.5):
net.eval() net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor)) img = torch.from_numpy(BasicDataset.preprocess(full_img))
img = img.unsqueeze(0) img = img.unsqueeze(0)
img = img.to(device = device, dtype = torch.float32) img = img.to(device = device, dtype = torch.float32)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment