Commit ca83308e by 王肇一

mrnet VOC dataset without data Augment

parent 1fc01530
......@@ -52,11 +52,14 @@ def main():
f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file)
name_base = 1
for path in os.listdir(args.input_dir):
for label_file in glob.glob(osp.join(args.input_dir + '/' + path, '*.json')):
print('Generating dataset from:', label_file)
with open(label_file) as f:
base = osp.splitext(osp.basename(label_file))[0]
# base = osp.splitext(osp.basename(label_file))[0]
base = str(name_base)
name_base+=1
out_img_file = osp.join(args.output_dir, 'JPEGImages', base + '.jpg')
out_lbl_file = osp.join(args.output_dir, 'SegmentationClass', base + '.npy')
out_png_file = osp.join(args.output_dir, 'SegmentationClassPNG', base + '.png')
......@@ -87,7 +90,7 @@ def lblsave(filename, lbl):
# Assume label ranses [-1, 254] for int32,
# and [0, 255] for uint8 as VOC.
if lbl.min() >= -1 and lbl.max() < 255:
lbl = np.array([1 if lbl[x, y] > 0 else 0 for x in range(200) for y in range(200)]).reshape([200, 200])
# lbl = np.array([1 if lbl[x,y]>0 else 0 for x in range(200) for y in range(200)]).reshape([200,200])
lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode = 'P')
colormap = imgviz.label_colormap()
lbl_pil.putpalette(colormap.flatten())
......
No preview for this file type
......@@ -14,24 +14,11 @@ from utils.dataset import BasicDataset,VOCSegmentation
from utils.eval import eval_net
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1):
# dataset = BasicDataset(dir_img, dir_mask)
# n_val = int(len(dataset) * val_percent)
# n_train = len(dataset) - n_val
# train, val = random_split(dataset, [n_train, n_val])
# train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
# val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
trans = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor()
])
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
trans = transforms.Compose([transforms.Resize(256),transforms.ToTensor()])
trainset = VOCSegmentation('data', 'train', trans, trans)
evalset = VOCSegmentation('data', 'traineval', trans, trans)
n_train = len(trainset)
......@@ -47,9 +34,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for imgs, true_masks in train_loader:
# imgs = batch['image']
# true_masks = batch['mask']
imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type)
......
......@@ -63,7 +63,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True
try:
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, val_percent = args.val / 100)
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
......
......@@ -8,12 +8,13 @@ import sys
import torch
import torch.nn as nn
from torch import optim
from torchvision import transforms
from tqdm import tqdm
from utils.eval import eval_net
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from utils.dataset import BasicDataset,VOCSegmentation
from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/'
......@@ -21,15 +22,16 @@ dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True):
trans = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
trainset = VOCSegmentation('data', 'train', trans, trans)
evalset = VOCSegmentation('data', 'traineval', trans, trans)
n_train = len(trainset)
n_val = len(evalset)
train_loader = DataLoader(trainset, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}')
global_step = 0
logging.info(f'''Starting training:
......@@ -40,7 +42,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
''')
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
......@@ -56,9 +57,9 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader:
imgs = batch['image']
true_masks = batch['mask']
for imgs,true_masks in train_loader:
# imgs = batch['image']
# true_masks = batch['mask']
# assert imgs.shape[1] == net.n_channels, \
# f'Network has been defined with {net.n_channels} input channels, ' \
# f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
......@@ -95,7 +96,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp and epoch % 4 == 0:
if save_cp and epoch % 5 == 0:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
......
......@@ -8,6 +8,9 @@ import torch
from torch.utils.data import Dataset
import logging
from PIL import Image
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import os
from torchvision.datasets.vision import VisionDataset
......@@ -73,7 +76,7 @@ class VOCSegmentation(VisionDataset):
mask_dir = os.path.join(voc_root, 'SegmentationClassPNG')
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it')
raise RuntimeError('Dataset not found or corrupted.')
split_f = os.path.join(voc_root, image_set.rstrip('\n') + '.txt')
......@@ -84,45 +87,26 @@ class VOCSegmentation(VisionDataset):
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks))
@classmethod
def preprocess(cls, pil_img):
pil_img = pil_img.resize((256, 256))
img_nd = np.array(pil_img)
if len(img_nd.shape) == 2:
img_nd = np.expand_dims(img_nd, axis = 2)
# HWC to CHW
img_trans = img_nd.transpose((2, 0, 1))
if img_trans.max() > 1:
img_trans = img_trans / 255
return img_trans
self.seq = iaa.Sequential([iaa.SomeOf((0, 5), [iaa.Noop(), iaa.Fliplr(0.5),
iaa.Sometimes(0.25, iaa.Dropout(p = (0, 0.1))), iaa.Affine(rotate = (-45, 45)),
iaa.ElasticTransformation(alpha = 50, sigma = 5)
], random_order = True)])
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images[index]).convert('L')
target = Image.open(self.masks[index]).convert('L')
pim = target.load()
for i in range(200):
for j in range(200):
pim[i, j] = 1 if pim[i,j] >0 else 0
pim[i, j] = 1 if pim[i, j] > 0 else 0
# img, target = self.seq(image=np.array(img), segmentation_maps = np.array(target))
if self.transforms is not None:
img, target = self.transforms(img, target)
# img = self.preprocess(img)
# target = self.preprocess(target)
return img, target
#return {'image':torch.from_numpy(np.asarray(img)), 'mask':torch.from_numpy(np.asarray(target))}
def __len__(self):
return len(self.images)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment