Commit ca83308e by 王肇一

mrnet VOC dataset without data Augment

parent 1fc01530
...@@ -52,11 +52,14 @@ def main(): ...@@ -52,11 +52,14 @@ def main():
f.writelines('\n'.join(class_names)) f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file) print('Saved class_names:', out_class_names_file)
name_base = 1
for path in os.listdir(args.input_dir): for path in os.listdir(args.input_dir):
for label_file in glob.glob(osp.join(args.input_dir + '/' + path, '*.json')): for label_file in glob.glob(osp.join(args.input_dir + '/' + path, '*.json')):
print('Generating dataset from:', label_file) print('Generating dataset from:', label_file)
with open(label_file) as f: with open(label_file) as f:
base = osp.splitext(osp.basename(label_file))[0] # base = osp.splitext(osp.basename(label_file))[0]
base = str(name_base)
name_base+=1
out_img_file = osp.join(args.output_dir, 'JPEGImages', base + '.jpg') out_img_file = osp.join(args.output_dir, 'JPEGImages', base + '.jpg')
out_lbl_file = osp.join(args.output_dir, 'SegmentationClass', base + '.npy') out_lbl_file = osp.join(args.output_dir, 'SegmentationClass', base + '.npy')
out_png_file = osp.join(args.output_dir, 'SegmentationClassPNG', base + '.png') out_png_file = osp.join(args.output_dir, 'SegmentationClassPNG', base + '.png')
...@@ -87,7 +90,7 @@ def lblsave(filename, lbl): ...@@ -87,7 +90,7 @@ def lblsave(filename, lbl):
# Assume label ranses [-1, 254] for int32, # Assume label ranses [-1, 254] for int32,
# and [0, 255] for uint8 as VOC. # and [0, 255] for uint8 as VOC.
if lbl.min() >= -1 and lbl.max() < 255: if lbl.min() >= -1 and lbl.max() < 255:
lbl = np.array([1 if lbl[x, y] > 0 else 0 for x in range(200) for y in range(200)]).reshape([200, 200]) # lbl = np.array([1 if lbl[x,y]>0 else 0 for x in range(200) for y in range(200)]).reshape([200,200])
lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode = 'P') lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode = 'P')
colormap = imgviz.label_colormap() colormap = imgviz.label_colormap()
lbl_pil.putpalette(colormap.flatten()) lbl_pil.putpalette(colormap.flatten())
......
No preview for this file type
...@@ -14,24 +14,11 @@ from utils.dataset import BasicDataset,VOCSegmentation ...@@ -14,24 +14,11 @@ from utils.dataset import BasicDataset,VOCSegmentation
from utils.eval import eval_net from utils.eval import eval_net
dir_img = 'data/train_imgs/'
dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoint/' dir_checkpoint = 'checkpoint/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1): def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
# dataset = BasicDataset(dir_img, dir_mask) trans = transforms.Compose([transforms.Resize(256),transforms.ToTensor()])
# n_val = int(len(dataset) * val_percent)
# n_train = len(dataset) - n_val
# train, val = random_split(dataset, [n_train, n_val])
# train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
# val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
trans = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor()
])
trainset = VOCSegmentation('data', 'train', trans, trans) trainset = VOCSegmentation('data', 'train', trans, trans)
evalset = VOCSegmentation('data', 'traineval', trans, trans) evalset = VOCSegmentation('data', 'traineval', trans, trans)
n_train = len(trainset) n_train = len(trainset)
...@@ -47,9 +34,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0 ...@@ -47,9 +34,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
epoch_loss = 0 epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar: with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for imgs, true_masks in train_loader: for imgs, true_masks in train_loader:
# imgs = batch['image']
# true_masks = batch['mask']
imgs = imgs.to(device = device, dtype = torch.float32) imgs = imgs.to(device = device, dtype = torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device = device, dtype = mask_type) true_masks = true_masks.to(device = device, dtype = mask_type)
......
...@@ -63,7 +63,7 @@ if __name__ == '__main__': ...@@ -63,7 +63,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True # cudnn.benchmark = True
try: try:
mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr, val_percent = args.val / 100) mrnet.train_net(net = net, device = device, epochs = args.epochs, batch_size = args.batchsize, lr = args.lr)
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth') torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt') logging.info('Saved interrupt')
......
...@@ -8,12 +8,13 @@ import sys ...@@ -8,12 +8,13 @@ import sys
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import optim from torch import optim
from torchvision import transforms
from tqdm import tqdm from tqdm import tqdm
from utils.eval import eval_net from utils.eval import eval_net
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset from utils.dataset import BasicDataset,VOCSegmentation
from torch.utils.data import DataLoader, random_split from torch.utils.data import DataLoader, random_split
dir_img = 'data/train_imgs/' dir_img = 'data/train_imgs/'
...@@ -21,15 +22,16 @@ dir_mask = 'data/train_masks/' ...@@ -21,15 +22,16 @@ dir_mask = 'data/train_masks/'
dir_checkpoint = 'checkpoints/' dir_checkpoint = 'checkpoints/'
def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0.1, save_cp = True, img_scale = 0.5): def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True):
dataset = BasicDataset(dir_img, dir_mask, img_scale) trans = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
n_val = int(len(dataset) * val_percent) trainset = VOCSegmentation('data', 'train', trans, trans)
n_train = len(dataset) - n_val evalset = VOCSegmentation('data', 'traineval', trans, trans)
train, val = random_split(dataset, [n_train, n_val]) n_train = len(trainset)
train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True) n_val = len(evalset)
val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True) train_loader = DataLoader(trainset, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
val_loader = DataLoader(evalset, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}') writer = SummaryWriter(comment = f'LR_{lr}_BS_{batch_size}')
global_step = 0 global_step = 0
logging.info(f'''Starting training: logging.info(f'''Starting training:
...@@ -40,7 +42,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0 ...@@ -40,7 +42,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
Validation size: {n_val} Validation size: {n_val}
Checkpoints: {save_cp} Checkpoints: {save_cp}
Device: {device.type} Device: {device.type}
Images scaling: {img_scale}
''') ''')
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8) # optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
...@@ -56,9 +57,9 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0 ...@@ -56,9 +57,9 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
epoch_loss = 0 epoch_loss = 0
with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar: with tqdm(total = n_train, desc = f'Epoch {epoch + 1}/{epochs}', unit = 'img') as pbar:
for batch in train_loader: for imgs,true_masks in train_loader:
imgs = batch['image'] # imgs = batch['image']
true_masks = batch['mask'] # true_masks = batch['mask']
# assert imgs.shape[1] == net.n_channels, \ # assert imgs.shape[1] == net.n_channels, \
# f'Network has been defined with {net.n_channels} input channels, ' \ # f'Network has been defined with {net.n_channels} input channels, ' \
# f'but loaded images have {imgs.shape[1]} channels. Please check that ' \ # f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
...@@ -95,7 +96,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0 ...@@ -95,7 +96,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
writer.add_images('masks/true', true_masks, global_step) writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step) writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp and epoch % 4 == 0: if save_cp and epoch % 5 == 0:
try: try:
os.mkdir(dir_checkpoint) os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory') logging.info('Created checkpoint directory')
......
...@@ -8,6 +8,9 @@ import torch ...@@ -8,6 +8,9 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
import logging import logging
from PIL import Image from PIL import Image
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import os import os
from torchvision.datasets.vision import VisionDataset from torchvision.datasets.vision import VisionDataset
...@@ -73,7 +76,7 @@ class VOCSegmentation(VisionDataset): ...@@ -73,7 +76,7 @@ class VOCSegmentation(VisionDataset):
mask_dir = os.path.join(voc_root, 'SegmentationClassPNG') mask_dir = os.path.join(voc_root, 'SegmentationClassPNG')
if not os.path.isdir(voc_root): if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') raise RuntimeError('Dataset not found or corrupted.')
split_f = os.path.join(voc_root, image_set.rstrip('\n') + '.txt') split_f = os.path.join(voc_root, image_set.rstrip('\n') + '.txt')
...@@ -84,45 +87,26 @@ class VOCSegmentation(VisionDataset): ...@@ -84,45 +87,26 @@ class VOCSegmentation(VisionDataset):
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks)) assert (len(self.images) == len(self.masks))
@classmethod self.seq = iaa.Sequential([iaa.SomeOf((0, 5), [iaa.Noop(), iaa.Fliplr(0.5),
def preprocess(cls, pil_img): iaa.Sometimes(0.25, iaa.Dropout(p = (0, 0.1))), iaa.Affine(rotate = (-45, 45)),
pil_img = pil_img.resize((256, 256)) iaa.ElasticTransformation(alpha = 50, sigma = 5)
], random_order = True)])
img_nd = np.array(pil_img)
if len(img_nd.shape) == 2:
img_nd = np.expand_dims(img_nd, axis = 2)
# HWC to CHW
img_trans = img_nd.transpose((2, 0, 1))
if img_trans.max() > 1:
img_trans = img_trans / 255
return img_trans
def __getitem__(self, index): def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images[index]).convert('L') img = Image.open(self.images[index]).convert('L')
target = Image.open(self.masks[index]).convert('L') target = Image.open(self.masks[index]).convert('L')
pim = target.load() pim = target.load()
for i in range(200): for i in range(200):
for j in range(200): for j in range(200):
pim[i, j] = 1 if pim[i,j] >0 else 0 pim[i, j] = 1 if pim[i, j] > 0 else 0
# img, target = self.seq(image=np.array(img), segmentation_maps = np.array(target))
if self.transforms is not None: if self.transforms is not None:
img, target = self.transforms(img, target) img, target = self.transforms(img, target)
# img = self.preprocess(img)
# target = self.preprocess(target)
return img, target return img, target
#return {'image':torch.from_numpy(np.asarray(img)), 'mask':torch.from_numpy(np.asarray(target))}
def __len__(self): def __len__(self):
return len(self.images) return len(self.images)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment