Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
I
Im
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
王肇一
Im
Commits
ca83308e
Commit
ca83308e
authored
Feb 07, 2020
by
王肇一
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
mrnet VOC dataset without data Augment
parent
1fc01530
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
34 additions
and
62 deletions
+34
-62
json_to_voc.py
cli/json_to_voc.py
+5
-2
MODEL.pth
data/module/MODEL.pth
+0
-0
train.py
mrnet/train.py
+2
-18
train.py
train.py
+1
-1
train.py
unet/train.py
+15
-14
dataset.py
utils/dataset.py
+11
-27
No files found.
cli/json_to_voc.py
View file @
ca83308e
...
...
@@ -52,11 +52,14 @@ def main():
f
.
writelines
(
'
\n
'
.
join
(
class_names
))
print
(
'Saved class_names:'
,
out_class_names_file
)
name_base
=
1
for
path
in
os
.
listdir
(
args
.
input_dir
):
for
label_file
in
glob
.
glob
(
osp
.
join
(
args
.
input_dir
+
'/'
+
path
,
'*.json'
)):
print
(
'Generating dataset from:'
,
label_file
)
with
open
(
label_file
)
as
f
:
base
=
osp
.
splitext
(
osp
.
basename
(
label_file
))[
0
]
# base = osp.splitext(osp.basename(label_file))[0]
base
=
str
(
name_base
)
name_base
+=
1
out_img_file
=
osp
.
join
(
args
.
output_dir
,
'JPEGImages'
,
base
+
'.jpg'
)
out_lbl_file
=
osp
.
join
(
args
.
output_dir
,
'SegmentationClass'
,
base
+
'.npy'
)
out_png_file
=
osp
.
join
(
args
.
output_dir
,
'SegmentationClassPNG'
,
base
+
'.png'
)
...
...
@@ -87,7 +90,7 @@ def lblsave(filename, lbl):
# Assume label ranses [-1, 254] for int32,
# and [0, 255] for uint8 as VOC.
if
lbl
.
min
()
>=
-
1
and
lbl
.
max
()
<
255
:
lbl
=
np
.
array
([
1
if
lbl
[
x
,
y
]
>
0
else
0
for
x
in
range
(
200
)
for
y
in
range
(
200
)])
.
reshape
([
200
,
200
])
# lbl = np.array([1 if lbl[x,y]>0 else 0 for x in range(200) for y in range(200)]).reshape([200,
200])
lbl_pil
=
PIL
.
Image
.
fromarray
(
lbl
.
astype
(
np
.
uint8
),
mode
=
'P'
)
colormap
=
imgviz
.
label_colormap
()
lbl_pil
.
putpalette
(
colormap
.
flatten
())
...
...
data/module/MODEL.pth
View file @
ca83308e
No preview for this file type
mrnet/train.py
View file @
ca83308e
...
...
@@ -14,24 +14,11 @@ from utils.dataset import BasicDataset,VOCSegmentation
from
utils.eval
import
eval_net
dir_img
=
'data/train_imgs/'
dir_mask
=
'data/train_masks/'
dir_checkpoint
=
'checkpoint/'
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
):
# dataset = BasicDataset(dir_img, dir_mask)
# n_val = int(len(dataset) * val_percent)
# n_train = len(dataset) - n_val
# train, val = random_split(dataset, [n_train, n_val])
# train_loader = DataLoader(train, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
# val_loader = DataLoader(val, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
trans
=
transforms
.
Compose
([
transforms
.
Resize
(
256
),
transforms
.
ToTensor
()
])
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
):
trans
=
transforms
.
Compose
([
transforms
.
Resize
(
256
),
transforms
.
ToTensor
()])
trainset
=
VOCSegmentation
(
'data'
,
'train'
,
trans
,
trans
)
evalset
=
VOCSegmentation
(
'data'
,
'traineval'
,
trans
,
trans
)
n_train
=
len
(
trainset
)
...
...
@@ -47,9 +34,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
epoch_loss
=
0
with
tqdm
(
total
=
n_train
,
desc
=
f
'Epoch {epoch + 1}/{epochs}'
,
unit
=
'img'
)
as
pbar
:
for
imgs
,
true_masks
in
train_loader
:
# imgs = batch['image']
# true_masks = batch['mask']
imgs
=
imgs
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
mask_type
=
torch
.
float32
if
net
.
n_classes
==
1
else
torch
.
long
true_masks
=
true_masks
.
to
(
device
=
device
,
dtype
=
mask_type
)
...
...
train.py
View file @
ca83308e
...
...
@@ -63,7 +63,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True
try
:
mrnet
.
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
val_percent
=
args
.
val
/
100
)
mrnet
.
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
)
except
KeyboardInterrupt
:
torch
.
save
(
net
.
state_dict
(),
'INTERRUPTED.pth'
)
logging
.
info
(
'Saved interrupt'
)
...
...
unet/train.py
View file @
ca83308e
...
...
@@ -8,12 +8,13 @@ import sys
import
torch
import
torch.nn
as
nn
from
torch
import
optim
from
torchvision
import
transforms
from
tqdm
import
tqdm
from
utils.eval
import
eval_net
from
torch.utils.tensorboard
import
SummaryWriter
from
utils.dataset
import
BasicDataset
from
utils.dataset
import
BasicDataset
,
VOCSegmentation
from
torch.utils.data
import
DataLoader
,
random_split
dir_img
=
'data/train_imgs/'
...
...
@@ -21,15 +22,16 @@ dir_mask = 'data/train_masks/'
dir_checkpoint
=
'checkpoints/'
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
,
save_cp
=
True
,
img_scale
=
0.5
):
dataset
=
BasicDataset
(
dir_img
,
dir_mask
,
img_scale
)
n_val
=
int
(
len
(
dataset
)
*
val_percent
)
n_train
=
len
(
dataset
)
-
n_val
train
,
val
=
random_split
(
dataset
,
[
n_train
,
n_val
])
train_loader
=
DataLoader
(
train
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
8
,
pin_memory
=
True
)
val_loader
=
DataLoader
(
val
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
8
,
pin_memory
=
True
)
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
save_cp
=
True
):
trans
=
transforms
.
Compose
([
transforms
.
Resize
(
256
),
transforms
.
ToTensor
()])
trainset
=
VOCSegmentation
(
'data'
,
'train'
,
trans
,
trans
)
evalset
=
VOCSegmentation
(
'data'
,
'traineval'
,
trans
,
trans
)
n_train
=
len
(
trainset
)
n_val
=
len
(
evalset
)
train_loader
=
DataLoader
(
trainset
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
8
,
pin_memory
=
True
)
val_loader
=
DataLoader
(
evalset
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
8
,
pin_memory
=
True
)
writer
=
SummaryWriter
(
comment
=
f
'LR_{lr}_BS_{batch_size}
_SCALE_{img_scale}
'
)
writer
=
SummaryWriter
(
comment
=
f
'LR_{lr}_BS_{batch_size}'
)
global_step
=
0
logging
.
info
(
f
'''Starting training:
...
...
@@ -40,7 +42,6 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
'''
)
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
...
...
@@ -56,9 +57,9 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
epoch_loss
=
0
with
tqdm
(
total
=
n_train
,
desc
=
f
'Epoch {epoch + 1}/{epochs}'
,
unit
=
'img'
)
as
pbar
:
for
batch
in
train_loader
:
imgs
=
batch
[
'image'
]
true_masks
=
batch
[
'mask'
]
for
imgs
,
true_masks
in
train_loader
:
#
imgs = batch['image']
#
true_masks = batch['mask']
# assert imgs.shape[1] == net.n_channels, \
# f'Network has been defined with {net.n_channels} input channels, ' \
# f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
...
...
@@ -95,7 +96,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
writer
.
add_images
(
'masks/true'
,
true_masks
,
global_step
)
writer
.
add_images
(
'masks/pred'
,
torch
.
sigmoid
(
masks_pred
)
>
0.5
,
global_step
)
if
save_cp
and
epoch
%
4
==
0
:
if
save_cp
and
epoch
%
5
==
0
:
try
:
os
.
mkdir
(
dir_checkpoint
)
logging
.
info
(
'Created checkpoint directory'
)
...
...
utils/dataset.py
View file @
ca83308e
...
...
@@ -8,6 +8,9 @@ import torch
from
torch.utils.data
import
Dataset
import
logging
from
PIL
import
Image
import
imgaug
as
ia
import
imgaug.augmenters
as
iaa
from
imgaug.augmentables.segmaps
import
SegmentationMapsOnImage
import
os
from
torchvision.datasets.vision
import
VisionDataset
...
...
@@ -73,7 +76,7 @@ class VOCSegmentation(VisionDataset):
mask_dir
=
os
.
path
.
join
(
voc_root
,
'SegmentationClassPNG'
)
if
not
os
.
path
.
isdir
(
voc_root
):
raise
RuntimeError
(
'Dataset not found or corrupted.'
+
' You can use download=True to download it'
)
raise
RuntimeError
(
'Dataset not found or corrupted.'
)
split_f
=
os
.
path
.
join
(
voc_root
,
image_set
.
rstrip
(
'
\n
'
)
+
'.txt'
)
...
...
@@ -84,45 +87,26 @@ class VOCSegmentation(VisionDataset):
self
.
masks
=
[
os
.
path
.
join
(
mask_dir
,
x
+
".png"
)
for
x
in
file_names
]
assert
(
len
(
self
.
images
)
==
len
(
self
.
masks
))
@classmethod
def
preprocess
(
cls
,
pil_img
):
pil_img
=
pil_img
.
resize
((
256
,
256
))
img_nd
=
np
.
array
(
pil_img
)
if
len
(
img_nd
.
shape
)
==
2
:
img_nd
=
np
.
expand_dims
(
img_nd
,
axis
=
2
)
# HWC to CHW
img_trans
=
img_nd
.
transpose
((
2
,
0
,
1
))
if
img_trans
.
max
()
>
1
:
img_trans
=
img_trans
/
255
return
img_trans
self
.
seq
=
iaa
.
Sequential
([
iaa
.
SomeOf
((
0
,
5
),
[
iaa
.
Noop
(),
iaa
.
Fliplr
(
0.5
),
iaa
.
Sometimes
(
0.25
,
iaa
.
Dropout
(
p
=
(
0
,
0.1
))),
iaa
.
Affine
(
rotate
=
(
-
45
,
45
)),
iaa
.
ElasticTransformation
(
alpha
=
50
,
sigma
=
5
)
],
random_order
=
True
)])
def
__getitem__
(
self
,
index
):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img
=
Image
.
open
(
self
.
images
[
index
])
.
convert
(
'L'
)
target
=
Image
.
open
(
self
.
masks
[
index
])
.
convert
(
'L'
)
pim
=
target
.
load
()
for
i
in
range
(
200
):
for
j
in
range
(
200
):
pim
[
i
,
j
]
=
1
if
pim
[
i
,
j
]
>
0
else
0
pim
[
i
,
j
]
=
1
if
pim
[
i
,
j
]
>
0
else
0
# img, target = self.seq(image=np.array(img), segmentation_maps = np.array(target))
if
self
.
transforms
is
not
None
:
img
,
target
=
self
.
transforms
(
img
,
target
)
# img = self.preprocess(img)
# target = self.preprocess(target)
return
img
,
target
#return {'image':torch.from_numpy(np.asarray(img)), 'mask':torch.from_numpy(np.asarray(target))}
def
__len__
(
self
):
return
len
(
self
.
images
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment