Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
I
Im
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
文档服务地址:
http://47.92.0.57:3000/
周报索引地址:
http://47.92.0.57:3000/s/NruNXRYmV
Open sidebar
王肇一
Im
Commits
9b02e4db
Commit
9b02e4db
authored
Feb 05, 2020
by
王肇一
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
mrnet
parent
ca3e4a87
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
496 additions
and
189 deletions
+496
-189
app.py
app.py
+8
-7
32to8.ipynb
cli/32to8.ipynb
+0
-0
__init__.py
mrnet/__init__.py
+4
-2
mrnet_module.py
mrnet/mrnet_module.py
+42
-62
mrnet_parts.py
mrnet/mrnet_parts.py
+72
-2
train.py
mrnet/train.py
+60
-2
resCalc.py
resCalc.py
+7
-10
resUnet.py
resUnet.py
+182
-2
train.py
train.py
+10
-98
__init__.py
unet/__init__.py
+2
-0
train.py
unet/train.py
+107
-2
dataset.py
utils/dataset.py
+2
-2
No files found.
app.py
View file @
9b02e4db
...
...
@@ -42,7 +42,7 @@ def step_1(net, args, device, list, position):
cv
.
imwrite
(
'data/masks/'
+
fn
[
0
]
+
'/'
+
fn
[
1
],
result
)
def
step_2
(
list
,
position
):
def
step_2
(
list
,
position
=
1
):
for
num
,
dir
in
enumerate
(
list
):
#df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)'))
values
=
[]
...
...
@@ -52,20 +52,21 @@ def step_2(list, position):
img
=
cv
.
imread
(
'data/imgs/'
+
dir
+
'/'
+
name
,
0
)
mask
=
cv
.
imread
(
'data/masks/'
+
dir
+
'/'
+
name
,
0
)
value
=
get_subarea_info
(
img
,
mask
)
value
,
count
=
get_subarea_info
(
img
,
mask
)
ug
=
0.0
if
str
.
lower
(
match_group
.
group
(
1
))
.
endswith
(
'ug'
):
ug
=
float
(
str
.
lower
(
match_group
.
group
(
1
))[:
-
2
])
elif
str
.
lower
(
match_group
.
group
(
1
))
==
'd2o'
:
ug
=
0
elif
str
.
lower
(
match_group
.
group
(
1
))
==
'd2o'
:
ug
=
0
elif
str
.
lower
(
match_group
.
group
(
1
))
==
'lb'
:
ug
=
-
1
iter
=
str
.
lower
(
match_group
.
group
(
2
))
values
.
append
({
'Intensity (a. u.)'
:
value
,
'ug'
:
ug
,
'iter'
:
iter
})
values
.
append
({
'Intensity (a. u.)'
:
value
,
'ug'
:
ug
,
'count'
:
count
})
df
=
pd
.
DataFrame
(
values
)
df
.
sort_values
(
'ug'
,
inplace
=
True
)
baseline
=
df
[
df
[
'ug'
]
==
0
][
'Intensity (a. u.)'
]
.
mean
()
*
0.62
df
.
replace
(
-
1
,
'lb'
,
inplace
=
True
)
df
.
replace
(
0
,
'd2o'
,
inplace
=
True
)
baseline
=
df
[
df
[
'ug'
]
==
'd2o'
][
'Intensity (a. u.)'
]
.
mean
()
*
0.62
sns
.
set_style
(
"darkgrid"
)
sns
.
catplot
(
x
=
'ug'
,
y
=
'Intensity (a. u.)'
,
kind
=
'bar'
,
palette
=
'vlag'
,
data
=
df
)
#sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0)
...
...
cli/32to8.ipynb
View file @
9b02e4db
This diff is collapsed.
Click to expand it.
mrnet/__init__.py
View file @
9b02e4db
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
from
.train
import
train_net
from
.mrnet_module
import
MultiUnet
\ No newline at end of file
mrnet/mrnet_module.py
View file @
9b02e4db
...
...
@@ -2,8 +2,8 @@
# -*- coding:utf-8 -*-
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
.mrnet_parts
import
MultiResBlock
,
ResPath
,
TransCompose
class
MultiUnet
(
nn
.
Module
):
...
...
@@ -14,80 +14,60 @@ class MultiUnet(nn.Module):
self
.
bilinear
=
bilinear
self
.
inconv
=
nn
.
Sequential
(
nn
.
Conv2d
(
n_channels
,
16
,
kernel_size
=
3
,
stride
=
1
,
padding_mode
=
'same'
),
nn
.
Conv2d
(
n_channels
,
8
,
kernel_size
=
3
,
stride
=
1
,
padding_mode
=
'same'
),
nn
.
BatchNorm2d
(
8
),
nn
.
ReLU
(
inplace
=
True
)
)
self
.
res1
=
MultiResBlock
(
16
,
32
)
self
.
res1
=
MultiResBlock
(
8
,
32
)
self
.
res2
=
MultiResBlock
(
self
.
res1
.
outc
,
32
*
2
)
self
.
res3
=
MultiResBlock
(
self
.
res2
.
outc
,
32
*
4
)
self
.
res4
=
MultiResBlock
(
self
.
res3
.
outc
,
32
*
8
)
self
.
res5
=
MultiResBlock
(
self
.
res4
.
outc
,
32
*
16
)
self
.
res6
=
MultiResBlock
(
self
.
res5
.
outc
,
32
*
8
)
self
.
res7
=
MultiResBlock
(
self
.
res6
.
outc
,
32
*
4
)
self
.
res8
=
MultiResBlock
(
self
.
res7
.
outc
,
32
*
2
)
self
.
res9
=
MultiResBlock
(
self
.
res8
.
outc
,
32
)
self
.
path1_9
=
ResPath
(
self
.
res1
.
outc
,
32
,
4
)
self
.
path2_8
=
ResPath
(
self
.
res2
.
outc
,
32
*
2
,
3
)
self
.
path3_7
=
ResPath
(
self
.
res3
.
outc
,
32
*
4
,
2
)
self
.
path4_6
=
ResPath
(
self
.
res4
.
outc
,
32
*
8
,
1
)
self
.
up6
=
TransCompose
(
self
.
res5
.
outc
,
32
*
8
)
self
.
res6
=
MultiResBlock
(
self
.
up6
.
outc
*
2
,
32
*
8
)
self
.
up7
=
TransCompose
(
self
.
res6
.
outc
,
32
*
4
)
self
.
res7
=
MultiResBlock
(
self
.
up7
.
outc
*
2
,
32
*
4
)
self
.
up8
=
TransCompose
(
self
.
res7
.
outc
,
32
*
2
)
self
.
res8
=
MultiResBlock
(
self
.
up8
.
outc
*
2
,
32
*
2
)
self
.
up9
=
TransCompose
(
self
.
res8
.
outc
,
32
)
self
.
res9
=
MultiResBlock
(
self
.
up9
.
outc
*
2
,
32
)
self
.
pool
=
nn
.
MaxPool2d
(
2
)
self
.
outconv
=
nn
.
Conv2d
(
self
.
res9
.
outc
,
n_classes
,
kernel_size
=
3
,
padding_mode
=
'same'
)
def
forward
(
self
,
x
):
pass
x
=
self
.
inconv
(
x
)
res1
=
self
.
res1
(
x
)
pool1
=
self
.
pool
(
res1
)
res2
=
self
.
res2
(
pool1
)
pool2
=
self
.
pool
(
res2
)
class
MultiResBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
U
,
alpha
=
1.67
):
super
()
.
__init__
()
self
.
U
=
U
self
.
W
=
U
*
alpha
self
.
inc
=
in_channels
self
.
outc
=
int
(
self
.
W
*
0.167
)
+
int
(
self
.
W
*
0.333
)
+
int
(
self
.
W
*
0.5
)
res3
=
self
.
res3
(
pool2
)
pool3
=
self
.
pool
(
res3
)
def
conv
(
self
,
in_channel
,
out_channel
,
kernel_size
):
return
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channel
,
out_channels
=
out_channel
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding_mode
=
'same'
),
nn
.
BatchNorm2d
(
out_channel
),
nn
.
ReLU
(
inplace
=
True
)
)
res4
=
self
.
res4
(
pool3
)
pool4
=
self
.
pool
(
res4
)
def
forward
(
self
,
x
):
shortcut
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
self
.
inc
,
out_channels
=
self
.
outc
,
kernel_size
=
1
,
stride
=
1
,
padding_mode
=
'same'
),
nn
.
BatchNorm2d
(
self
.
outc
))(
x
)
conv3
=
self
.
conv
(
self
.
inc
,
int
(
self
.
W
*
0.167
),
3
)(
x
)
conv5
=
self
.
conv
(
int
(
self
.
W
*
0.167
),
int
(
self
.
W
*
0.333
),
3
)(
conv3
)
conv7
=
self
.
conv
(
int
(
self
.
W
*
0.333
),
int
(
self
.
W
*
0.5
),
3
)(
conv5
)
result
=
torch
.
cat
([
conv3
,
conv5
,
conv7
],
2
)
result
=
nn
.
BatchNorm2d
(
self
.
outc
)(
result
)
result
=
torch
.
add
(
result
,
shortcut
)
result
=
nn
.
Sequential
(
nn
.
ReLU
(
inplace
=
True
),
nn
.
BatchNorm2d
(
self
.
outc
))(
result
)
return
result
class
ResPath
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
length
):
super
()
.
__init__
()
self
.
length
=
length
self
.
inc
=
in_channels
self
.
outc
=
out_channels
def
unit
(
self
,
in_channels
,
out_channels
,
x
):
shortcut
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding_mode
=
'same'
),
nn
.
BatchNorm2d
(
self
.
outc
))(
x
)
conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding_mode
=
'same'
),
nn
.
BatchNorm2d
(
out_channels
),
nn
.
ReLU
(
inplace
=
True
))(
x
)
result
=
torch
.
add
(
conv
,
shortcut
)
return
nn
.
Sequential
(
nn
.
ReLU
(
inplace
=
True
),
nn
.
BatchNorm2d
(
out_channels
))(
result
)
res5
=
self
.
res5
(
pool4
)
def
forward
(
self
,
x
):
x
=
self
.
unit
(
self
.
inc
,
self
.
outc
,
x
)
for
i
in
range
(
self
.
length
-
1
):
x
=
self
.
unit
(
self
.
outc
,
self
.
outc
,
x
)
return
x
res1
=
self
.
path1_9
(
res1
)
res2
=
self
.
path2_8
(
res2
)
res3
=
self
.
path3_7
(
res3
)
res4
=
self
.
path4_6
(
res4
)
res6
=
self
.
res6
(
torch
.
cat
([
res4
,
self
.
up6
(
res5
)],
1
))
res7
=
self
.
res7
(
torch
.
cat
([
res3
,
self
.
up7
(
res6
)],
1
))
res8
=
self
.
res8
(
torch
.
cat
([
res2
,
self
.
up8
(
res7
)],
1
))
res9
=
self
.
res9
(
torch
.
cat
([
res1
,
self
.
up9
(
res8
)],
1
))
out
=
self
.
outconv
(
res9
)
return
out
mrnet/mrnet_parts.py
View file @
9b02e4db
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
import
torch
import
torchvision
import
torch.nn
as
nn
import
torch.nn.functional
as
F
class
MultiResBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
U
,
alpha
=
1.67
):
super
()
.
__init__
()
self
.
U
=
U
self
.
W
=
U
*
alpha
self
.
inc
=
in_channels
self
.
outc
=
int
(
self
.
W
*
0.167
)
+
int
(
self
.
W
*
0.333
)
+
int
(
self
.
W
*
0.5
)
def
conv
(
self
,
in_channel
,
out_channel
,
kernel_size
):
return
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channel
,
out_channels
=
out_channel
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding
=
1
,
padding_mode
=
'same'
),
nn
.
BatchNorm2d
(
out_channel
),
nn
.
ReLU
(
inplace
=
True
))
def
forward
(
self
,
x
):
shortcut
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
self
.
inc
,
out_channels
=
self
.
outc
,
kernel_size
=
1
,
stride
=
1
),
nn
.
BatchNorm2d
(
self
.
outc
))(
x
)
conv3
=
self
.
conv
(
self
.
inc
,
int
(
self
.
W
*
0.167
),
3
)(
x
)
conv5
=
self
.
conv
(
int
(
self
.
W
*
0.167
),
int
(
self
.
W
*
0.333
),
3
)(
conv3
)
conv7
=
self
.
conv
(
int
(
self
.
W
*
0.333
),
int
(
self
.
W
*
0.5
),
3
)(
conv5
)
result
=
torch
.
cat
([
conv3
,
conv5
,
conv7
],
1
)
result
=
nn
.
BatchNorm2d
(
self
.
outc
)(
result
)
result
=
torch
.
add
(
result
,
shortcut
)
result
=
nn
.
Sequential
(
nn
.
ReLU
(
inplace
=
True
),
nn
.
BatchNorm2d
(
self
.
outc
))(
result
)
return
result
class
TransCompose
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
):
super
()
.
__init__
()
self
.
inc
=
in_channels
self
.
outc
=
out_channels
def
forward
(
self
,
x
):
return
nn
.
Sequential
(
nn
.
ConvTranspose2d
(
in_channels
=
self
.
inc
,
out_channels
=
self
.
outc
,
kernel_size
=
3
,
stride
=
2
),
nn
.
BatchNorm2d
(
self
.
outc
)
)(
x
)
class
ResPath
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
length
):
super
()
.
__init__
()
self
.
length
=
length
self
.
inc
=
in_channels
self
.
outc
=
out_channels
def
unit
(
self
,
in_channels
,
out_channels
,
x
):
shortcut
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
1
),
nn
.
BatchNorm2d
(
self
.
outc
))(
x
)
conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
BatchNorm2d
(
out_channels
),
nn
.
ReLU
(
inplace
=
True
))(
x
)
result
=
torch
.
add
(
conv
,
shortcut
)
return
nn
.
Sequential
(
nn
.
ReLU
(
inplace
=
True
),
nn
.
BatchNorm2d
(
out_channels
))(
result
)
def
forward
(
self
,
x
):
x
=
self
.
unit
(
self
.
inc
,
self
.
outc
,
x
)
for
i
in
range
(
self
.
length
-
1
):
x
=
self
.
unit
(
self
.
outc
,
self
.
outc
,
x
)
return
x
mrnet/train.py
View file @
9b02e4db
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
import
os
import
logging
from
tqdm
import
tqdm
import
torch
import
torch.nn
as
nn
from
torch
import
optim
from
torch.utils.data
import
DataLoader
,
random_split
from
utils.dataset
import
BasicDataset
from
utils.eval
import
eval_net
dir_img
=
'data/train_imgs/'
dir_mask
=
'data/train_masks/'
dir_checkpoint
=
'checkpoint'
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
,
save_cp
=
True
,
img_scale
=
0.5
):
dataset
=
BasicDataset
(
dir_img
,
dir_mask
,
img_scale
)
n_val
=
int
(
len
(
dataset
)
*
val_percent
)
n_train
=
len
(
dataset
)
-
n_val
train
,
val
=
random_split
(
dataset
,
[
n_train
,
n_val
])
train_loader
=
DataLoader
(
train
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
8
,
pin_memory
=
True
)
val_loader
=
DataLoader
(
val
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
8
,
pin_memory
=
True
)
optimizer
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
lr
)
criterion
=
nn
.
BCELoss
()
for
epoch
in
tqdm
(
range
(
epochs
)):
net
.
train
()
epoch_loss
=
0
for
batch
in
train_loader
:
imgs
=
batch
[
'image'
]
true_masks
=
batch
[
'mask'
]
imgs
=
imgs
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
mask_type
=
torch
.
float32
if
net
.
n_classes
==
1
else
torch
.
long
true_masks
=
true_masks
.
to
(
device
=
device
,
dtype
=
mask_type
)
masks_pred
=
net
(
imgs
)
loss
=
criterion
(
masks_pred
,
true_masks
)
epoch_loss
+=
loss
.
item
()
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
val_score
=
eval_net
(
net
,
val_loader
,
device
,
n_val
)
logging
.
info
(
'Validation cross entropy: {}'
.
format
(
val_score
))
try
:
os
.
mkdir
(
dir_checkpoint
)
logging
.
info
(
'Created checkpoint directory'
)
except
OSError
:
pass
torch
.
save
(
net
.
state_dict
(),
dir_checkpoint
+
f
'CP_epoch{epoch + 1}.pth'
)
torch
.
save
(
net
.
state_dict
(),
'MODEL.pth'
)
resCalc.py
View file @
9b02e4db
...
...
@@ -65,21 +65,18 @@ def get_subarea_info(img, mask):
back_value
=
img
[
np
.
where
(
real_bg_mask
!=
0
)]
back_mean
=
np
.
mean
(
back_value
)
info
.
append
({
'id'
:
i
,
'size'
:
area_size
,
'area_mean'
:
area_mean
,
'back_mean'
:
back_mean
})
info
.
append
({
'mean'
:
area_mean
,
'back'
:
back_mean
,
'size'
:
area_size
})
# endif
df
=
pd
.
DataFrame
(
info
)
df
[
'Intensity (a. u.)'
]
=
df
[
'area_mean'
]
-
df
[
'back_mean'
]
median
=
np
.
median
(
df
[
'
Intensity (a. u.)
'
])
median
=
np
.
median
(
df
[
'
mean
'
])
b
=
1.4826
mad
=
b
*
np
.
median
(
np
.
abs
(
df
[
'
Intensity (a. u.)
'
]
-
median
))
mad
=
b
*
np
.
median
(
np
.
abs
(
df
[
'
mean
'
]
-
median
))
lower_limit
=
median
-
(
3
*
mad
)
upper_limit
=
median
+
(
3
*
mad
)
#df = df[df['Intensity (a. u.)'] > lower_limit]
#df = df[df['Intensity (a. u.)'] < upper_limit]
value
=
df
[
'Intensity (a. u.)'
]
.
mean
()
df
=
df
[
df
[
'mean'
]
>=
lower_limit
]
df
=
df
[
df
[
'mean'
]
<=
upper_limit
]
df
[
'value'
]
=
df
[
'mean'
]
-
df
[
'back'
]
return
value
return
(
df
[
'value'
]
*
df
[
'size'
])
.
sum
()
/
df
[
'size'
]
.
sum
(),
df
.
shape
[
0
]
resUnet.py
View file @
9b02e4db
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
from
torch
import
nn
from
torch.nn
import
functional
as
F
import
torch
from
torchvision
import
models
import
torchvision
def
conv3x3
(
in_
,
out
):
return
nn
.
Conv2d
(
in_
,
out
,
3
,
padding
=
1
)
class
ConvRelu
(
nn
.
Module
):
def
__init__
(
self
,
in_
,
out
):
super
()
.
__init__
()
self
.
conv
=
conv3x3
(
in_
,
out
)
self
.
activation
=
nn
.
ReLU
(
inplace
=
True
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
activation
(
x
)
return
x
class
DecoderBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
middle_channels
,
out_channels
):
super
()
.
__init__
()
self
.
block
=
nn
.
Sequential
(
ConvRelu
(
in_channels
,
middle_channels
),
nn
.
ConvTranspose2d
(
middle_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
1
,
output_padding
=
1
),
nn
.
ReLU
(
inplace
=
True
))
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
class
Interpolate
(
nn
.
Module
):
def
__init__
(
self
,
size
=
None
,
scale_factor
=
None
,
mode
=
'nearest'
,
align_corners
=
False
):
super
(
Interpolate
,
self
)
.
__init__
()
self
.
interp
=
nn
.
functional
.
interpolate
self
.
size
=
size
self
.
mode
=
mode
self
.
scale_factor
=
scale_factor
self
.
align_corners
=
align_corners
def
forward
(
self
,
x
):
x
=
self
.
interp
(
x
,
size
=
self
.
size
,
scale_factor
=
self
.
scale_factor
,
mode
=
self
.
mode
,
align_corners
=
self
.
align_corners
)
return
x
class
DecoderBlockV2
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
middle_channels
,
out_channels
,
is_deconv
=
True
):
super
(
DecoderBlockV2
,
self
)
.
__init__
()
self
.
in_channels
=
in_channels
if
is_deconv
:
"""
Paramaters for Deconvolution were chosen to avoid artifacts, following
link https://distill.pub/2016/deconv-checkerboard/
"""
self
.
block
=
nn
.
Sequential
(
ConvRelu
(
in_channels
,
middle_channels
),
nn
.
ConvTranspose2d
(
middle_channels
,
out_channels
,
kernel_size
=
4
,
stride
=
2
,
padding
=
1
),
nn
.
ReLU
(
inplace
=
True
))
else
:
self
.
block
=
nn
.
Sequential
(
Interpolate
(
scale_factor
=
2
,
mode
=
'bilinear'
),
ConvRelu
(
in_channels
,
middle_channels
),
ConvRelu
(
middle_channels
,
out_channels
),
)
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
class
UNet11
(
nn
.
Module
):
def
__init__
(
self
,
num_filters
=
32
,
pretrained
=
False
):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network is used
True - encoder is pre-trained with VGG11
"""
super
()
.
__init__
()
self
.
pool
=
nn
.
MaxPool2d
(
2
,
2
)
self
.
encoder
=
models
.
vgg11
(
pretrained
=
pretrained
)
.
features
self
.
relu
=
self
.
encoder
[
1
]
self
.
conv1
=
self
.
encoder
[
0
]
self
.
conv2
=
self
.
encoder
[
3
]
self
.
conv3s
=
self
.
encoder
[
6
]
self
.
conv3
=
self
.
encoder
[
8
]
self
.
conv4s
=
self
.
encoder
[
11
]
self
.
conv4
=
self
.
encoder
[
13
]
self
.
conv5s
=
self
.
encoder
[
16
]
self
.
conv5
=
self
.
encoder
[
18
]
self
.
center
=
DecoderBlock
(
num_filters
*
8
*
2
,
num_filters
*
8
*
2
,
num_filters
*
8
)
self
.
dec5
=
DecoderBlock
(
num_filters
*
(
16
+
8
),
num_filters
*
8
*
2
,
num_filters
*
8
)
self
.
dec4
=
DecoderBlock
(
num_filters
*
(
16
+
8
),
num_filters
*
8
*
2
,
num_filters
*
4
)
self
.
dec3
=
DecoderBlock
(
num_filters
*
(
8
+
4
),
num_filters
*
4
*
2
,
num_filters
*
2
)
self
.
dec2
=
DecoderBlock
(
num_filters
*
(
4
+
2
),
num_filters
*
2
*
2
,
num_filters
)
self
.
dec1
=
ConvRelu
(
num_filters
*
(
2
+
1
),
num_filters
)
self
.
final
=
nn
.
Conv2d
(
num_filters
,
1
,
kernel_size
=
1
)
def
forward
(
self
,
x
):
conv1
=
self
.
relu
(
self
.
conv1
(
x
))
conv2
=
self
.
relu
(
self
.
conv2
(
self
.
pool
(
conv1
)))
conv3s
=
self
.
relu
(
self
.
conv3s
(
self
.
pool
(
conv2
)))
conv3
=
self
.
relu
(
self
.
conv3
(
conv3s
))
conv4s
=
self
.
relu
(
self
.
conv4s
(
self
.
pool
(
conv3
)))
conv4
=
self
.
relu
(
self
.
conv4
(
conv4s
))
conv5s
=
self
.
relu
(
self
.
conv5s
(
self
.
pool
(
conv4
)))
conv5
=
self
.
relu
(
self
.
conv5
(
conv5s
))
center
=
self
.
center
(
self
.
pool
(
conv5
))
dec5
=
self
.
dec5
(
torch
.
cat
([
center
,
conv5
],
1
))
dec4
=
self
.
dec4
(
torch
.
cat
([
dec5
,
conv4
],
1
))
dec3
=
self
.
dec3
(
torch
.
cat
([
dec4
,
conv3
],
1
))
dec2
=
self
.
dec2
(
torch
.
cat
([
dec3
,
conv2
],
1
))
dec1
=
self
.
dec1
(
torch
.
cat
([
dec2
,
conv1
],
1
))
return
self
.
final
(
dec1
)
class
UNet16
(
nn
.
Module
):
def
__init__
(
self
,
num_classes
=
1
,
num_filters
=
32
,
pretrained
=
False
,
is_deconv
=
False
):
"""
:param num_classes:
:param num_filters:
:param pretrained:
False - no pre-trained network used
True - encoder pre-trained with VGG16
:is_deconv:
False: bilinear interpolation is used in decoder
True: deconvolution is used in decoder
"""
super
()
.
__init__
()
self
.
n_classes
=
num_classes
self
.
pool
=
nn
.
MaxPool2d
(
2
,
2
)
self
.
encoder
=
torchvision
.
models
.
vgg16
(
pretrained
=
pretrained
)
.
features
self
.
relu
=
nn
.
ReLU
(
inplace
=
True
)
self
.
conv1
=
nn
.
Sequential
(
self
.
encoder
[
0
],
self
.
relu
,
self
.
encoder
[
2
],
self
.
relu
)
self
.
conv2
=
nn
.
Sequential
(
self
.
encoder
[
5
],
self
.
relu
,
self
.
encoder
[
7
],
self
.
relu
)
self
.
conv3
=
nn
.
Sequential
(
self
.
encoder
[
10
],
self
.
relu
,
self
.
encoder
[
12
],
self
.
relu
,
self
.
encoder
[
14
],
self
.
relu
)
self
.
conv4
=
nn
.
Sequential
(
self
.
encoder
[
17
],
self
.
relu
,
self
.
encoder
[
19
],
self
.
relu
,
self
.
encoder
[
21
],
self
.
relu
)
self
.
conv5
=
nn
.
Sequential
(
self
.
encoder
[
24
],
self
.
relu
,
self
.
encoder
[
26
],
self
.
relu
,
self
.
encoder
[
28
],
self
.
relu
)
self
.
center
=
DecoderBlock
(
512
,
num_filters
*
8
*
2
,
num_filters
*
8
)
self
.
dec5
=
DecoderBlock
(
512
+
num_filters
*
8
,
num_filters
*
8
*
2
,
num_filters
*
8
)
self
.
dec4
=
DecoderBlock
(
512
+
num_filters
*
8
,
num_filters
*
8
*
2
,
num_filters
*
8
)
self
.
dec3
=
DecoderBlock
(
256
+
num_filters
*
8
,
num_filters
*
4
*
2
,
num_filters
*
2
)
self
.
dec2
=
DecoderBlock
(
128
+
num_filters
*
2
,
num_filters
*
2
*
2
,
num_filters
)
self
.
dec1
=
ConvRelu
(
64
+
num_filters
,
num_filters
)
self
.
final
=
nn
.
Conv2d
(
num_filters
,
num_classes
,
kernel_size
=
1
)
def
forward
(
self
,
x
):
conv1
=
self
.
conv1
(
x
)
conv2
=
self
.
conv2
(
self
.
pool
(
conv1
))
conv3
=
self
.
conv3
(
self
.
pool
(
conv2
))
conv4
=
self
.
conv4
(
self
.
pool
(
conv3
))
conv5
=
self
.
conv5
(
self
.
pool
(
conv4
))
center
=
self
.
center
(
self
.
pool
(
conv5
))
dec5
=
self
.
dec5
(
torch
.
cat
([
center
,
conv5
],
1
))
dec4
=
self
.
dec4
(
torch
.
cat
([
dec5
,
conv4
],
1
))
dec3
=
self
.
dec3
(
torch
.
cat
([
dec4
,
conv3
],
1
))
dec2
=
self
.
dec2
(
torch
.
cat
([
dec3
,
conv2
],
1
))
dec1
=
self
.
dec1
(
torch
.
cat
([
dec2
,
conv1
],
1
))
if
self
.
n_classes
>
1
:
x_out
=
F
.
log_softmax
(
self
.
final
(
dec1
),
dim
=
1
)
else
:
x_out
=
self
.
final
(
dec1
)
return
x_out
train.py
View file @
9b02e4db
...
...
@@ -4,112 +4,21 @@ import os
import
sys
import
torch
import
torch.nn
as
nn
from
torch
import
optim
from
tqdm
import
tqdm
from
utils.eval
import
eval_net
import
unet
import
mrnet
from
mrnet
import
MultiUnet
from
unet
import
UNet
from
torch.utils.tensorboard
import
SummaryWriter
from
utils.dataset
import
BasicDataset
from
torch.utils.data
import
DataLoader
,
random_split
dir_img
=
'data/train_imgs/'
dir_mask
=
'data/train_masks/'
dir_checkpoint
=
'checkpoints/'
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
,
save_cp
=
True
,
img_scale
=
0.5
):
dataset
=
BasicDataset
(
dir_img
,
dir_mask
,
img_scale
)
n_val
=
int
(
len
(
dataset
)
*
val_percent
)
n_train
=
len
(
dataset
)
-
n_val
train
,
val
=
random_split
(
dataset
,
[
n_train
,
n_val
])
train_loader
=
DataLoader
(
train
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
8
,
pin_memory
=
True
)
val_loader
=
DataLoader
(
val
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
8
,
pin_memory
=
True
)
writer
=
SummaryWriter
(
comment
=
f
'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}'
)
global_step
=
0
logging
.
info
(
f
'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
'''
)
optimizer
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
lr
,
weight_decay
=
1e-8
)
#optimizer = optim.RMSprop(net.parameters(), lr = lr, weight_decay = 1e-8)
if
net
.
n_classes
>
1
:
criterion
=
nn
.
CrossEntropyLoss
()
else
:
criterion
=
nn
.
BCEWithLogitsLoss
()
for
epoch
in
range
(
epochs
):
net
.
train
()
epoch_loss
=
0
with
tqdm
(
total
=
n_train
,
desc
=
f
'Epoch {epoch + 1}/{epochs}'
,
unit
=
'img'
)
as
pbar
:
for
batch
in
train_loader
:
imgs
=
batch
[
'image'
]
true_masks
=
batch
[
'mask'
]
assert
imgs
.
shape
[
1
]
==
net
.
n_channels
,
\
f
'Network has been defined with {net.n_channels} input channels, '
\
f
'but loaded images have {imgs.shape[1]} channels. Please check that '
\
'the images are loaded correctly.'
imgs
=
imgs
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
mask_type
=
torch
.
float32
if
net
.
n_classes
==
1
else
torch
.
long
true_masks
=
true_masks
.
to
(
device
=
device
,
dtype
=
mask_type
)
masks_pred
=
net
(
imgs
)
loss
=
criterion
(
masks_pred
,
true_masks
)
epoch_loss
+=
loss
.
item
()
writer
.
add_scalar
(
'Loss/train'
,
loss
.
item
(),
global_step
)
pbar
.
set_postfix
(
**
{
'loss (batch)'
:
loss
.
item
()})
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
pbar
.
update
(
imgs
.
shape
[
0
])
global_step
+=
1
if
global_step
%
(
len
(
dataset
)
//
(
10
*
batch_size
))
==
0
:
val_score
=
eval_net
(
net
,
val_loader
,
device
,
n_val
)
if
net
.
n_classes
>
1
:
logging
.
info
(
'Validation cross entropy: {}'
.
format
(
val_score
))
writer
.
add_scalar
(
'Loss/test'
,
val_score
,
global_step
)
else
:
logging
.
info
(
'Validation Dice Coeff: {}'
.
format
(
val_score
))
writer
.
add_scalar
(
'Dice/test'
,
val_score
,
global_step
)
writer
.
add_images
(
'images'
,
imgs
,
global_step
)
if
net
.
n_classes
==
1
:
writer
.
add_images
(
'masks/true'
,
true_masks
,
global_step
)
writer
.
add_images
(
'masks/pred'
,
torch
.
sigmoid
(
masks_pred
)
>
0.5
,
global_step
)
if
save_cp
:
try
:
os
.
mkdir
(
dir_checkpoint
)
logging
.
info
(
'Created checkpoint directory'
)
except
OSError
:
pass
torch
.
save
(
net
.
state_dict
(),
dir_checkpoint
+
f
'CP_epoch{epoch + 1}.pth'
)
logging
.
info
(
f
'Checkpoint {epoch + 1} saved !'
)
writer
.
close
()
def
get_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Train the UNet on images and target masks'
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
)
parser
.
add_argument
(
'-e'
,
'--epochs'
,
metavar
=
'E'
,
type
=
int
,
default
=
5
,
help
=
'Number of epochs'
,
parser
.
add_argument
(
'-e'
,
'--epochs'
,
metavar
=
'E'
,
type
=
int
,
default
=
1
,
help
=
'Number of epochs'
,
dest
=
'epochs'
)
parser
.
add_argument
(
'-b'
,
'--batch-size'
,
metavar
=
'B'
,
type
=
int
,
nargs
=
'?'
,
default
=
1
,
help
=
'Batch size'
,
dest
=
'batchsize'
)
...
...
@@ -117,7 +26,7 @@ def get_args():
help
=
'Learning rate'
,
dest
=
'lr'
)
parser
.
add_argument
(
'-f'
,
'--load'
,
dest
=
'load'
,
type
=
str
,
default
=
False
,
help
=
'Load model from a .pth file'
)
parser
.
add_argument
(
'-s'
,
'--scale'
,
dest
=
'scale'
,
type
=
float
,
default
=
0.5
,
parser
.
add_argument
(
'-s'
,
'--scale'
,
dest
=
'scale'
,
type
=
float
,
default
=
2.56
,
help
=
'Downscaling factor of the images'
)
parser
.
add_argument
(
'-v'
,
'--validation'
,
dest
=
'val'
,
type
=
float
,
default
=
10.0
,
help
=
'Percent of the data that is used as validation (0-100)'
)
...
...
@@ -137,7 +46,10 @@ if __name__ == '__main__':
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
net
=
UNet
(
n_channels
=
1
,
n_classes
=
1
)
#net = UNet(n_channels = 1, n_classes = 1)
net
=
MultiUnet
(
n_channels
=
1
,
n_classes
=
1
)
logging
.
info
(
f
'Network:
\n
'
f
'
\t
{net.n_channels} input channels
\n
'
f
'
\t
{net.n_classes} output channels (classes)
\n
'
...
...
@@ -152,7 +64,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True
try
:
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
mrnet
.
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
img_scale
=
args
.
scale
,
val_percent
=
args
.
val
/
100
)
except
KeyboardInterrupt
:
torch
.
save
(
net
.
state_dict
(),
'INTERRUPTED.pth'
)
...
...
unet/__init__.py
View file @
9b02e4db
from
.unet_model
import
UNet
from
.train
import
train_net
\ No newline at end of file
unet/train.py
View file @
9b02e4db
#!/usr/bin/env python
# -*- coding:utf-8 -*-
\ No newline at end of file
# -*- coding:utf-8 -*-
import
argparse
import
logging
import
os
import
sys
import
torch
import
torch.nn
as
nn
from
torch
import
optim
from
tqdm
import
tqdm
from
utils.eval
import
eval_net
from
torch.utils.tensorboard
import
SummaryWriter
from
utils.dataset
import
BasicDataset
from
torch.utils.data
import
DataLoader
,
random_split
dir_img
=
'data/train_imgs/'
dir_mask
=
'data/train_masks/'
dir_checkpoint
=
'checkpoints/'
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
,
save_cp
=
True
,
img_scale
=
0.5
):
dataset
=
BasicDataset
(
dir_img
,
dir_mask
,
img_scale
)
n_val
=
int
(
len
(
dataset
)
*
val_percent
)
n_train
=
len
(
dataset
)
-
n_val
train
,
val
=
random_split
(
dataset
,
[
n_train
,
n_val
])
train_loader
=
DataLoader
(
train
,
batch_size
=
batch_size
,
shuffle
=
True
,
num_workers
=
8
,
pin_memory
=
True
)
val_loader
=
DataLoader
(
val
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
8
,
pin_memory
=
True
)
writer
=
SummaryWriter
(
comment
=
f
'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}'
)
global_step
=
0
logging
.
info
(
f
'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_cp}
Device: {device.type}
Images scaling: {img_scale}
'''
)
# optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
optimizer
=
optim
.
RMSprop
(
net
.
parameters
(),
lr
=
lr
,
weight_decay
=
1e-8
)
# criterion = nn.BCEWithLogitsLoss()
if
net
.
n_classes
>
1
:
criterion
=
nn
.
CrossEntropyLoss
()
else
:
criterion
=
nn
.
BCEWithLogitsLoss
()
for
epoch
in
range
(
epochs
):
net
.
train
()
epoch_loss
=
0
with
tqdm
(
total
=
n_train
,
desc
=
f
'Epoch {epoch + 1}/{epochs}'
,
unit
=
'img'
)
as
pbar
:
for
batch
in
train_loader
:
imgs
=
batch
[
'image'
]
true_masks
=
batch
[
'mask'
]
# assert imgs.shape[1] == net.n_channels, \
# f'Network has been defined with {net.n_channels} input channels, ' \
# f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
# 'the images are loaded correctly.'
imgs
=
imgs
.
to
(
device
=
device
,
dtype
=
torch
.
float32
)
mask_type
=
torch
.
float32
if
net
.
n_classes
==
1
else
torch
.
long
true_masks
=
true_masks
.
to
(
device
=
device
,
dtype
=
mask_type
)
masks_pred
=
net
(
imgs
)
loss
=
criterion
(
masks_pred
,
true_masks
)
epoch_loss
+=
loss
.
item
()
writer
.
add_scalar
(
'Loss/train'
,
loss
.
item
(),
global_step
)
pbar
.
set_postfix
(
**
{
'loss (batch)'
:
loss
.
item
()})
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
pbar
.
update
(
imgs
.
shape
[
0
])
global_step
+=
1
# if global_step % (len(dataset) // (10 * batch_size)) == 0:
val_score
=
eval_net
(
net
,
val_loader
,
device
,
n_val
)
if
net
.
n_classes
>
1
:
logging
.
info
(
'Validation cross entropy: {}'
.
format
(
val_score
))
writer
.
add_scalar
(
'Loss/test'
,
val_score
,
global_step
)
else
:
logging
.
info
(
'Validation Dice Coeff: {}'
.
format
(
val_score
))
writer
.
add_scalar
(
'Dice/test'
,
val_score
,
global_step
)
writer
.
add_images
(
'images'
,
imgs
,
global_step
)
if
net
.
n_classes
==
1
:
writer
.
add_images
(
'masks/true'
,
true_masks
,
global_step
)
writer
.
add_images
(
'masks/pred'
,
torch
.
sigmoid
(
masks_pred
)
>
0.5
,
global_step
)
if
save_cp
and
epoch
%
4
==
0
:
try
:
os
.
mkdir
(
dir_checkpoint
)
logging
.
info
(
'Created checkpoint directory'
)
except
OSError
:
pass
torch
.
save
(
net
.
state_dict
(),
dir_checkpoint
+
f
'CP_epoch{epoch + 1}.pth'
)
logging
.
info
(
f
'Checkpoint {epoch + 1} saved !'
)
torch
.
save
(
net
.
state_dict
(),
'MODEL.pth'
)
writer
.
close
()
\ No newline at end of file
utils/dataset.py
View file @
9b02e4db
...
...
@@ -13,7 +13,7 @@ class BasicDataset(Dataset):
self
.
imgs_dir
=
imgs_dir
self
.
masks_dir
=
masks_dir
self
.
scale
=
scale
assert
0
<
scale
<=
1
,
'Scale must be between 0 and 1'
#
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self
.
ids
=
[
splitext
(
file
)[
0
]
for
file
in
listdir
(
imgs_dir
)
if
not
file
.
startswith
(
'.'
)]
...
...
@@ -25,7 +25,7 @@ class BasicDataset(Dataset):
@classmethod
def
preprocess
(
cls
,
pil_img
,
scale
):
w
,
h
=
pil_img
.
size
newW
,
newH
=
int
(
scale
*
w
),
int
(
scale
*
h
)
newW
,
newH
=
256
,
256
#
int(scale * w), int(scale * h)
assert
newW
>
0
and
newH
>
0
,
'Scale is too small'
pil_img
=
pil_img
.
resize
((
newW
,
newH
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment