Commit af74ed93 by 王肇一

Unet inside

parent 863c51ff
from __future__ import print_function
import argparse
import glob
import json
import os
import os.path as osp
import sys
import imgviz
import numpy as np
import PIL.Image
import labelme
def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('input_dir', help='input annotated directory')
parser.add_argument('output_dir', help='output dataset directory')
parser.add_argument('--labels', help='labels file', required=True)
parser.add_argument(
'--noviz', help='no visualization', action='store_true'
)
args = parser.parse_args()
if osp.exists(args.output_dir):
print('Output directory already exists:', args.output_dir)
sys.exit(1)
os.makedirs(args.output_dir)
os.makedirs(osp.join(args.output_dir, 'JPEGImages'))
os.makedirs(osp.join(args.output_dir, 'SegmentationClass'))
os.makedirs(osp.join(args.output_dir, 'SegmentationClassPNG'))
if not args.noviz:
os.makedirs(
osp.join(args.output_dir, 'SegmentationClassVisualization')
)
print('Creating dataset:', args.output_dir)
class_names = []
class_name_to_id = {}
for i, line in enumerate(open(args.labels).readlines()):
class_id = i - 1 # starts with -1
class_name = line.strip()
class_name_to_id[class_name] = class_id
if class_id == -1:
assert class_name == '__ignore__'
continue
elif class_id == 0:
assert class_name == '_background_'
class_names.append(class_name)
class_names = tuple(class_names)
print('class_names:', class_names)
out_class_names_file = osp.join(args.output_dir, 'class_names.txt')
with open(out_class_names_file, 'w') as f:
f.writelines('\n'.join(class_names))
print('Saved class_names:', out_class_names_file)
for label_file in glob.glob(osp.join(args.input_dir, '*.json')):
print('Generating dataset from:', label_file)
with open(label_file) as f:
base = osp.splitext(osp.basename(label_file))[0]
out_img_file = osp.join(
args.output_dir, 'JPEGImages', base + '.jpg')
out_lbl_file = osp.join(
args.output_dir, 'SegmentationClass', base + '.npy')
out_png_file = osp.join(
args.output_dir, 'SegmentationClassPNG', base + '.png')
if not args.noviz:
out_viz_file = osp.join(
args.output_dir,
'SegmentationClassVisualization',
base + '.jpg',
)
data = json.load(f)
img_file = osp.join(osp.dirname(label_file), data['imagePath'])
img = np.asarray(PIL.Image.open(img_file))
PIL.Image.fromarray(img).save(out_img_file)
lbl = labelme.utils.shapes_to_label(
img_shape=img.shape,
shapes=data['shapes'],
label_name_to_value=class_name_to_id,
)
labelme.utils.lblsave(out_png_file, lbl)
np.save(out_lbl_file, lbl)
if not args.noviz:
viz = imgviz.label2rgb(
label=lbl,
img=img,
font_size=15,
label_names=class_names,
loc='rb',
)
imgviz.io.imsave(out_viz_file, viz)
if __name__ == '__main__':
main()
\ No newline at end of file
import torch
from torch.autograd import Function
class DiceCoeff(Function):
"""Dice coeff for individual examples"""
def forward(self, input, target):
self.save_for_backward(input, target)
eps = 0.0001
self.inter = torch.dot(input.view(-1), target.view(-1))
self.union = torch.sum(input) + torch.sum(target) + eps
t = (2 * self.inter.float() + eps) / self.union.float()
return t
# This function has only a single output, so it gets only one gradient
def backward(self, grad_output):
input, target = self.saved_variables
grad_input = grad_target = None
if self.needs_input_grad[0]:
grad_input = grad_output * 2 * (target * self.union - self.inter) \
/ (self.union * self.union)
if self.needs_input_grad[1]:
grad_target = None
return grad_input, grad_target
def dice_coeff(input, target):
"""Dice coeff for batches"""
if input.is_cuda:
s = torch.FloatTensor(1).cuda().zero_()
else:
s = torch.FloatTensor(1).zero_()
for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1])
return s / (i + 1)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torchvision
from torchvision import transforms, models, datasets
from torch import nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 64, kernel_size = 3)
self.conv2 = nn.Conv2d(in_channels = 64, out_channels = 64, kernel_size = 3)
self.pool1 = nn.MaxPool2d(kernel_size = 2)
self.conv3 = nn.Conv2d(in_channels = 64,out_channels = 128,kernel_size = 3)
self.conv4 = nn.Conv2d(in_channels = 128,out_channels = 128,kernel_size = 3)
self.pool2 = nn.MaxPool2d(kernel_size = 2)
self.conv5 = nn.Conv2d(in_channels = 128,out_channels = 256,kernel_size = 3)
self.conv6 = nn.Conv2d(in_channels = 256,out_channels = 256,kernel_size = 3)
self.pool3 = nn.MaxPool2d(kernel_size = 2)
self.conv7 = nn.Conv2d(in_channels = 256,out_channels = 512,kernel_size = 3)
self.conv8 = nn.Conv2d(in_channels = 512,out_channels = 512,kernel_size = 3)
self.pool4 = nn.MaxPool2d(kernel_size = 2)
self.conv9 = nn.Conv2d(in_channels = 512,out_channels = 1024,kernel_size = 3)
self.conv10 = nn.Conv2d(in_channels = 1024,out_channels = 1024,kernel_size = 3)
self.up1 = nn.C
\ No newline at end of file
import torch
import torch.nn.functional as F
from tqdm import tqdm
from dice_loss import dice_coeff
def eval_net(net, loader, device, n_val):
"""Evaluation without the densecrf with the dice coefficient"""
net.eval()
tot = 0
with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
for batch in loader:
imgs = batch['image']
true_masks = batch['mask']
imgs = imgs.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)
mask_pred = net(imgs)
for true_mask, pred in zip(true_masks, mask_pred):
pred = (pred > 0.5).float()
if net.n_classes > 1:
tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
else:
tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
pbar.update(imgs.shape[0])
return tot / n_val
import argparse
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset
def predict_img(net,
full_img,
device,
scale_factor=1,
out_threshold=0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
with torch.no_grad():
output = net(img)
if net.n_classes > 1:
probs = F.softmax(output, dim=1)
else:
probs = torch.sigmoid(output)
probs = probs.squeeze(0)
tf = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize(full_img.size[1]),
transforms.ToTensor()
]
)
probs = tf(probs.cpu())
full_mask = probs.squeeze().cpu().numpy()
return full_mask > out_threshold
def get_args():
parser = argparse.ArgumentParser(description='Predict masks from input images',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model', '-m', default='MODEL.pth',
metavar='FILE',
help="Specify the file in which the model is stored")
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
help='filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
help='Filenames of ouput images')
parser.add_argument('--viz', '-v', action='store_true',
help="Visualize the images as they are processed",
default=False)
parser.add_argument('--no-save', '-n', action='store_true',
help="Do not save the output masks",
default=False)
parser.add_argument('--mask-threshold', '-t', type=float,
help="Minimum probability value to consider a mask pixel white",
default=0.5)
parser.add_argument('--scale', '-s', type=float,
help="Scale factor for the input images",
default=0.5)
return parser.parse_args()
def get_output_filenames(args):
in_files = args.input
out_files = []
if not args.output:
for f in in_files:
pathsplit = os.path.splitext(f)
out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
elif len(in_files) != len(args.output):
logging.error("Input files and output files are not of the same length")
raise SystemExit()
else:
out_files = args.output
return out_files
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))
if __name__ == "__main__":
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
net = UNet(n_channels=3, n_classes=1)
logging.info("Loading model {}".format(args.model))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
net.to(device=device)
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info("Model loaded !")
for i, fn in enumerate(in_files):
logging.info("\nPredicting image {} ...".format(fn))
img = Image.open(fn)
mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
device=device)
if not args.no_save:
out_fn = out_files[i]
result = mask_to_image(mask)
result.save(out_files[i])
logging.info("Mask saved to {}".format(out_files[i]))
if args.viz:
logging.info("Visualizing results for image {}, close to continue ...".format(fn))
plot_img_and_mask(img, mask)
""" Submit code specific to the kaggle challenge"""
import os
import torch
from PIL import Image
import numpy as np
from predict import predict_img
from unet import UNet
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
def rle_encode(mask_image):
pixels = mask_image.flatten()
# We avoid issues with '1' at the start or end (at the corners of
# the original image) by setting those pixels to '0' explicitly.
# We do not expect these to be non-zero for an accurate mask,
# so this should not harm the score.
pixels[0] = 0
pixels[-1] = 0
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
runs[1::2] = runs[1::2] - runs[:-1:2]
return runs
def submit(net, gpu=False):
"""Used for Kaggle submission: predicts and encode all test images"""
dir = 'data/test/'
N = len(list(os.listdir(dir)))
with open('SUBMISSION.csv', 'a') as f:
f.write('img,rle_mask\n')
for index, i in enumerate(os.listdir(dir)):
print('{}/{}'.format(index, N))
img = Image.open(dir + i)
mask = predict_img(net, img, gpu)
enc = rle_encode(mask)
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
if __name__ == '__main__':
net = UNet(3, 1).cuda()
net.load_state_dict(torch.load('MODEL.pth'))
submit(net, True)
import matplotlib.pyplot as plt
def plot_img_and_mask(img, mask):
classes = mask.shape[2] if len(mask.shape) > 2 else 1
fig, ax = plt.subplots(1, classes + 1)
ax[0].set_title('Input image')
ax[0].imshow(img)
if classes > 1:
for i in range(classes):
ax[i+1].set_title(f'Output mask (class {i+1})')
ax[i+1].imshow(mask[:, :, i])
else:
ax[1].set_title(f'Output mask')
ax[1].imshow(mask)
plt.xticks([]), plt.yticks([])
plt.show()
from os.path import splitext
from os import listdir
import numpy as np
from glob import glob
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image
class BasicDataset(Dataset):
def __init__(self, imgs_dir, masks_dir, scale=1):
self.imgs_dir = imgs_dir
self.masks_dir = masks_dir
self.scale = scale
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')]
logging.info(f'Creating dataset with {len(self.ids)} examples')
def __len__(self):
return len(self.ids)
@classmethod
def preprocess(cls, pil_img, scale):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((newW, newH))
img_nd = np.array(pil_img)
if len(img_nd.shape) == 2:
img_nd = np.expand_dims(img_nd, axis=2)
# HWC to CHW
img_trans = img_nd.transpose((2, 0, 1))
if img_trans.max() > 1:
img_trans = img_trans / 255
return img_trans
def __getitem__(self, i):
idx = self.ids[i]
mask_file = glob(self.masks_dir + idx + '*')
img_file = glob(self.imgs_dir + idx + '*')
assert len(mask_file) == 1, \
f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
assert len(img_file) == 1, \
f'Either no image or multiple images found for the ID {idx}: {img_file}'
mask = Image.open(mask_file[0])
img = Image.open(img_file[0])
assert img.size == mask.size, \
f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(img, self.scale)
mask = self.preprocess(mask, self.scale)
return {'image': torch.from_numpy(img), 'mask': torch.from_numpy(mask)}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment