Project: 王肇一 / Im

Commit 447cf1b5
Authored Aug 04, 2020 by 王肇一
Commit message: 202008 update
Parent commit: 597d01e2
Showing 9 changed files with 168 additions and 112 deletions (+168 / -112)
app.py                   +21  -31
mrnet/mrnet_module.py     +2   -2
mrnet/train.py           +25   -9
resCalc.py               +41  -34
train.py                  +8   -4
unet/train.py            +19  -17
utils/dice_loss.py       +18   -8
utils/eval.py             +7   -7
utils/focal_loss.py      +27   -0

app.py

@@ -14,12 +14,13 @@ import argparse
 import logging
 import os
 import re
 from collections import OrderedDict

 from unet import UNet
 from mrnet import MultiUnet
 from utils.predict import predict_img, predict
-from resCalc import save_img, get_subarea_info, save_img_mask, get_subarea_info_avgBG, get_subarea_info_fast, \
-    get_subarea_info_fast_outlier
+from resCalc import save_img, get_subarea_info, save_img_mask, get_subarea_info_avgBG, get_subarea_info_fast_outlier, \
+    get_subarea_info_nobg


 def divide_list(list, step):

@@ -52,7 +53,7 @@ def step_1_32bit(net,args,device,list,position):
         norm = cv.normalize(np.array(img), None, 0, 255, cv.NORM_MINMAX, cv.CV_8U)
         norm = Image.fromarray(norm.astype('uint8'))
         mask = predict_img(net=net, full_img=norm, out_threshold=args.mask_threshold, device=device)
-        # mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
+        #mask = predict(net = net, full_img = img, out_threshold = args.mask_threshold, device = device)
         result = (mask * 255).astype(np.uint8)
         # save_img({'ori': img, 'mask': result}, fn[0], fn[1])

@@ -74,8 +75,6 @@ def step_2(list, position=1):
         match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name)
         img = cv.imread('data/imgs/' + dir + '/' + name, 0)
         mask = cv.imread('data/masks/' + dir + '/' + name, 0)
-        # value = get_subarea_info_fast(img, mask)
-        # value, count = get_subarea_info(img, mask)
         value = get_subarea_info_avgBG(img, mask)
         if value is not None:
             ug = 0.0

@@ -99,7 +98,7 @@ def step_2(list, position=1):
     plt.savefig('data/output/' + dir + '.png')


-def step_2_32bit(list, position=1):
+def step_2_32bit(list, position=1):
     for num, dir in enumerate(list):
         # df = pd.DataFrame(columns = ('ug', 'iter', 'id', 'size', 'area_mean', 'back_mean', 'Intensity (a. u.)'))
         values = []

@@ -107,9 +106,14 @@ def step_2_32bit(list,position=1):
         for name in tqdm(names, desc=f'Period{num + 1}/{len(list)}', position=position):
             match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.tif', name)
             img = cv.imread('data/imgs/' + dir + '/' + name, flags=cv.IMREAD_ANYDEPTH)
+            # img_mean = np.mean(img)
+            # if np.max(img) + 125 - img_mean < 255:
+            #     img = img + 125 - img_mean
+            # else:
+            #     img = img + 255 - np.max(img)
             mask = cv.imread('data/masks/' + dir + '/' + name, 0)
-            value = get_subarea_info_fast_outlier(img, mask)  #value,shape = get_subarea_info(img,mask)
+            value = get_subarea_info_fast_outlier(img, mask)  # get_subarea_info_nobg(img,mask)
             if value is not None:
                 ug = 0.0
                 if str.lower(match_group.group(1)).endswith('ug'):

@@ -129,6 +133,7 @@ def step_2_32bit(list,position=1):
     sns.set_style("darkgrid")
     sns.catplot(x='ug', y='Intensity(a.u.)', kind='bar', palette='vlag', data=df)
     # sns.swarmplot(x = "ug", y = "Intensity (a. u.)", data = df, size = 2, color = ".3", linewidth = 0)
     plt.suptitle(dir)
     plt.axhline(y=baseline_high)
+    plt.axhline(y=baseline_low)
     plt.savefig('data/output/' + dir + '.png')

@@ -166,7 +171,14 @@ def cli_main():
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         logging.info(f'Using device {device}')
         net.to(device=device)
-        net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location=device))
+        model = torch.load('data/module/' + args.module + '.pth', map_location=device)
+        d = OrderedDict()
+        for key, value in model.items():
+            tmp = key[7:]
+            d[tmp] = value
+        net.load_state_dict(d)
+        #net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location = device))
         logging.info("Model loaded !")
         pool = Pool(args.process)

@@ -184,28 +196,6 @@ def cli_main():
             pool.apply_async(step_2_32bit, args=(list, i))
         pool.close()
         pool.join()
-    elif args.step == 3:
-        net = UNet(n_channels=1, n_classes=1)
-        # net = MultiUnet(n_channels = 1,n_classes = 1)
-        logging.info("Loading model {}".format(args.module))
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        logging.info(f'Using device {device}')
-        net.to(device=device)
-        net.load_state_dict(torch.load('data/module/' + args.module + '.pth', map_location=device))
-        logging.info("Model loaded !")
-        pool = Pool(args.process)
-        for i, list in enumerate(seperate_path):
-            pool.apply_async(step_1, args=(net, args, device, list, i))
-        pool.close()
-        pool.join()
-        dir = [x for x in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs/'))]
-        sep_dir = divide_list(dir, args.process)
-        for i, list in enumerate(sep_dir):
-            pool.apply_async(step_2, args=(list, i))
-        pool.close()
-        pool.join()


 if __name__ == '__main__':
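
The new loading block in cli_main replaces a direct net.load_state_dict call with a loop that copies every checkpoint entry under key[7:]. Those seven characters are the 'module.' prefix that torch.nn.DataParallel prepends to parameter names when a wrapped model's state_dict is saved, which matches the DataParallel wrapping introduced in train.py below. A minimal sketch of the same idea, with the checkpoint path and key names chosen here purely for illustration:

from collections import OrderedDict

import torch


def strip_dataparallel_prefix(state_dict):
    """Drop a leading 'module.' from every key, if present.

    A checkpoint saved from an nn.DataParallel-wrapped model stores keys such
    as 'module.inc.conv.weight', while the bare model expects 'inc.conv.weight'.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        cleaned[key[7:] if key.startswith('module.') else key] = value
    return cleaned


# Illustrative usage only -- 'example.pth' and `net` are placeholders:
# checkpoint = torch.load('data/module/example.pth', map_location='cpu')
# net.load_state_dict(strip_dataparallel_prefix(checkpoint))
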
mrnet/mrnet_module.py

@@ -42,8 +42,8 @@ class MultiUnet(nn.Module):
         self.pool = nn.MaxPool2d(2)
         self.outconv = nn.Sequential(
             nn.Conv2d(self.res9.outc, n_classes, kernel_size=1),
-            nn.Softmax()
-            # nn.Sigmoid()
+            # nn.Softmax()
+            nn.Sigmoid()
         )
         # self.outconv = nn.Conv2d(self.res9.outc, n_classes,kernel_size = 1)
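
The head swap from nn.Softmax() to nn.Sigmoid() matters because MultiUnet is instantiated elsewhere in this commit with n_classes = 1: softmax taken over a channel dimension of size one returns 1.0 for every pixel, so the old head could never express a foreground probability. A quick standalone check of that behaviour (assuming the usual NCHW layout, channel axis at dim=1):

import torch

logits = torch.randn(2, 1, 4, 4)                      # N=2, a single output channel, 4x4 maps

soft = torch.softmax(logits, dim=1)                   # softmax over a one-element channel axis
print(torch.allclose(soft, torch.ones_like(soft)))    # True: every pixel collapses to 1.0

probs = torch.sigmoid(logits)                         # per-pixel foreground probability in (0, 1)
print(probs.min().item(), probs.max().item())
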
mrnet/train.py

@@ -4,16 +4,20 @@
 import os
 import logging

 from tqdm import tqdm
 import numpy as np
 import torch
 import torch.nn as nn
 from torch import optim
 from torch.optim import lr_scheduler
+from torch.optim.rmsprop import RMSprop
 from torchvision import transforms
 from torch.utils.data import DataLoader, random_split
 from torch.utils.tensorboard import SummaryWriter

 from utils.dataset import BasicDataset, VOCSegmentation
 from utils.eval import eval_net, eval_jac
+from utils.dice_loss import DiceLoss
+from utils.focal_loss import FocalLoss

 dir_checkpoint = 'checkpoint/'

@@ -27,9 +31,14 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
     train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
     val_loader = DataLoader(evalset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

-    optimizer = optim.Adam(net.parameters(), lr=lr)
-    criterion = nn.BCELoss()  # nn.BCEWithLogitsLoss()
-    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps=1e-20, factor=0.5, patience=5)
     writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}')
     global_step = 0

+    #optimizer = optim.Adam(net.parameters(), lr = lr)
+    optimizer = RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.99)
+    #criterion = nn.BCELoss()
+    criterion = FocalLoss(alpha=1, gamma=2, logits=False)
+    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps=1e-20, factor=0.1, patience=10)

     for epoch in range(epochs):
         net.train()

@@ -37,25 +46,32 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1):
         with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
             for imgs, true_masks in train_loader:
                 imgs = imgs.to(device=device, dtype=torch.float32)
-                mask_type = torch.float32 if net.n_classes == 1 else torch.long
+                mask_type = torch.float32  # if net.n_classes == 1 else torch.long
                 true_masks = true_masks.to(device=device, dtype=mask_type)

                 optimizer.zero_grad()
                 masks_pred = net(imgs)
                 loss = criterion(masks_pred, true_masks)
                 epoch_loss += loss.item()
                 writer.add_scalar('Loss/train', loss.item(), global_step)
                 pbar.set_postfix(**{'loss (batch)': loss.item()})

                 loss.backward()
                 optimizer.step()

                 pbar.update(imgs.shape[0])
                 global_step += 1

         dice = eval_net(net, val_loader, device, n_val)
         jac = eval_jac(net, val_loader, device, n_val)
         # overall_acc, avg_per_class_acc, avg_jacc, avg_dice = eval_multi(net, val_loader, device, n_val)
         scheduler.step(dice)
         lr = optimizer.param_groups[0]['lr']
-        logging.info(f'Avg Dice:{dice} Jaccard:{jac}\n'
-                     f'Learning Rate:{scheduler.get_lr()[0]}')
+        logging.info(f'Avg Dice:{dice}\t'
+                     f'Learning Rate:{lr}')
         writer.add_scalar('Dice/test', dice, global_step)

         writer.add_images('images', imgs, global_step)
         # if net.n_classes == 1:
         writer.add_images('masks/true', true_masks, global_step)
         writer.add_images('masks/pred', masks_pred > 0.5, global_step)

         if epoch % 5 == 0:
             try:
                 os.mkdir(dir_checkpoint)
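
Two details of this training loop are easy to miss: ReduceLROnPlateau is still driven by the validation Dice score (scheduler.step(dice) with mode 'max'), and the logged learning rate now comes from optimizer.param_groups[0]['lr'] rather than scheduler.get_lr(). In the PyTorch releases contemporary with this commit, ReduceLROnPlateau does not subclass _LRScheduler and offers no get_lr(), so the optimizer is the reliable place to read the current rate. A small self-contained sketch of that pattern (the constant dice value stands in for eval_net):

from torch import nn, optim
from torch.optim import lr_scheduler

model = nn.Linear(4, 1)                       # stand-in for the real network
optimizer = optim.RMSprop(model.parameters(), lr=0.1, weight_decay=1e-8, momentum=0.99)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.1, patience=10)

for epoch in range(3):
    dice = 0.5                                # placeholder for eval_net(net, val_loader, device, n_val)
    scheduler.step(dice)                      # 'max' mode: cut the LR when dice stops improving
    current_lr = optimizer.param_groups[0]['lr']
    print(f'epoch {epoch}: lr = {current_lr}')
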
resCalc.py

@@ -8,14 +8,15 @@ import logging
 import os
 import re


-def save_img_mask(img, mask, dir, name):
+def save_img_mask(img, mask, dir, name):
     plt.figure(dpi=300)
     plt.suptitle(name)
-    plt.imshow(img, 'gray')  #print(img.shape)
+    plt.imshow(img, 'gray')  # print(img.shape)
     mask = cv.cvtColor(mask, cv.COLOR_GRAY2RGB)
-    mask[:, :, 2] = 0
-    mask[:, :, 0] = 0
+    mask[:, :, 2] = 0
+    mask[:, :, 0] = 0
     plt.imshow(mask, alpha=0.25, cmap='rainbow')
     try:
         os.makedirs('data/output/' + dir)

@@ -50,14 +51,15 @@ def get_subarea_info(img, mask):
         group = np.where(labels == i)
         area_size = len(group[0])
-        #if area_size > 10:  # drop regions that are too small
+        # if area_size > 10:  # drop regions that are too small
         area_value = img[group]
         area_mean = np.mean(area_value)

         # Background Value
         pos = [(group[0][k], group[1][k]) for k in range(len(group[0]))]
-        area_points = np.array([mask[x, y] if (x, y) in pos else 0 for x in range(200) for y in range(200)], dtype=np.uint8).reshape([200, 200])
+        area_points = np.array([mask[x, y] if (x, y) in pos else 0 for x in range(200) for y in range(200)], dtype=np.uint8).reshape([200, 200])
         kernel = np.ones((15, 15), np.uint8)
         bg_area_mask = cv.erode(area_points, kernel)
         surround_bg_mask = cv.bitwise_xor(bg_area_mask, 255 - area_points)

@@ -65,8 +67,7 @@ def get_subarea_info(img, mask):
         back_value = img[np.where(real_bg_mask != 0)]
         back_mean = np.mean(back_value)
-        info.append({'mean': area_mean, 'back': back_mean, 'size': area_size})
-        # endif
+        info.append({'mean': area_mean, 'back': back_mean, 'size': area_size})
+        # endif
     df = pd.DataFrame(info)
     median = np.median(df['mean'])

@@ -77,7 +78,7 @@ def get_subarea_info(img, mask):
     df = df[df['mean'] >= lower_limit]
     df = df[df['mean'] <= upper_limit]
-    df['value'] = df['mean'] - df['back']
+    df['value'] = df['mean'] - df['back']
     return (df['value'] * df['size']).sum() / df['size'].sum(), df.shape[0]

@@ -97,31 +98,17 @@ def get_subarea_info_avgBG(img, mask):
             area_mean = np.mean(area_value)
             size += area_size
-            value += (area_mean - bg) * area_size
+            value += (area_mean - bg) * area_size
         return value / size


-def get_subarea_info_fast(img, mask):
+def get_subarea_info_fast_outlier(img, mask):
     if mask.max() == 0:
         return None
     else:
         kernel = np.ones((15, 15), np.uint8)
         bg_area_mask = cv.dilate(mask, kernel)
         surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask)
-        sig_value = np.mean(img[np.where(mask != 0)])
-        back_value = np.mean(img[np.where(surround_bg_mask != 0)])
-        return sig_value - back_value
-
-
-def get_subarea_info_fast_outlier(img, mask):
-    if mask.max() == 0:
-        return None
-    else:
-        kernel = np.ones((15, 15), np.uint8)
-        bg_area_mask = cv.dilate(mask, kernel)
-        surround_bg_mask = cv.bitwise_xor(bg_area_mask, mask)
         sig_mean = np.mean(img[np.where(mask != 0)])
         back_value = np.mean(img[np.where(surround_bg_mask != 0)])
         median = np.median(img[np.where(mask != 0)])
         b = 1.4826

@@ -131,14 +118,35 @@ def get_subarea_info_fast_outlier(img,mask):
         bg_median = np.median(img[np.where(surround_bg_mask != 0)])
         bg_mad = b * np.median(np.abs(img[np.where(surround_bg_mask != 0)] - bg_median))
-        bg_lower_limit = bg_median - (3 * bg_mad)
-        bg_upper_limit = bg_median + (3 * bg_mad)
+        bg_lower_limit = bg_median - (3 * bg_mad)
+        bg_upper_limit = bg_median + (3 * bg_mad)
         res = img[np.where(mask != 0)]
-        res = res[res >= lower_limit]
-        res = res[res <= upper_limit]
+        res = res[res >= lower_limit]
+        res = res[res <= upper_limit]
         bg = img[np.where(surround_bg_mask != 0)]
-        bg = bg[bg >= bg_lower_limit]
-        bg = bg[bg <= bg_upper_limit]
-        return np.mean(res) - np.mean(bg)
\ No newline at end of file
+        bg = bg[bg >= bg_lower_limit]
+        bg = bg[bg <= bg_upper_limit]
+        return np.mean(res) - np.mean(bg)
+
+
+def get_subarea_info_nobg(img, mask):
+    if mask.max() == 0:
+        return None
+    else:
+        median = np.median(img[np.where(mask != 0)])
+        b = 1.4826
+        mad = b * np.median(np.abs(img[np.where(mask != 0)] - median))
+        lower_limit = median - (3 * mad)
+        upper_limit = median + (3 * mad)
+        res = img[np.where(mask != 0)]
+        res = res[res >= lower_limit]
+        res = res[res <= upper_limit]
+        return np.mean(res)
+
+# todo: def get_subarea_info_fast_outlier_cluster(img,mask):
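
Both new helpers lean on the same robust-statistics idea: the median absolute deviation is scaled by b = 1.4826 so that, for roughly Gaussian data, it estimates the standard deviation, and pixels farther than three scaled MADs from the median are discarded as outliers before averaging. A standalone sketch of that filter (NumPy only; the function name and example values are illustrative):

import numpy as np


def mad_filtered_mean(values, n_mads=3.0):
    """Mean of `values` after discarding points outside median +/- n_mads * MAD.

    The 1.4826 factor makes the median absolute deviation a consistent
    estimator of the standard deviation for Gaussian data.
    """
    values = np.asarray(values, dtype=np.float64)
    median = np.median(values)
    mad = 1.4826 * np.median(np.abs(values - median))
    keep = (values >= median - n_mads * mad) & (values <= median + n_mads * mad)
    return values[keep].mean()


# Example: a single bright outlier barely shifts the filtered estimate.
pixels = np.array([10.0, 11.0, 9.0, 10.5, 9.5, 250.0])
print(mad_filtered_mean(pixels))   # 10.0 -- the outlier 250 is rejected
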
train.py

@@ -4,6 +4,7 @@ import os
 import sys

 import torch
+from torch import nn

 import unet
 import mrnet

@@ -49,9 +50,12 @@ if __name__ == '__main__':
     #net = UNet(n_channels = 1, n_classes = 1)
     net = MultiUnet(n_channels=1, n_classes=1)
-    logging.info(f'Network:\n'
-                 f'\t{net.n_channels} input channels\n'
-                 f'\t{net.n_classes} output channels (classes)\n')
+    if torch.cuda.device_count() > 1:
+        net = nn.DataParallel(net)
+    # logging.info(f'Network:\n'
+    #              f'\t{net.n_channels} input channels\n'
+    #              f'\t{net.n_classes} output channels (classes)\n')

     if args.load:
         net.load_state_dict(torch.load(args.load, map_location=device))

@@ -62,7 +66,7 @@ if __name__ == '__main__':
     # cudnn.benchmark = True

     try:
-        unet.train_net(net=net, device=device, epochs=args.epochs, batch_size=args.batchsize, lr=args.lr)
+        mrnet.train_net(net=net, device=device, epochs=args.epochs, batch_size=args.batchsize, lr=args.lr)
     except KeyboardInterrupt:
         torch.save(net.state_dict(), 'INTERRUPTED.pth')
         logging.info('Saved interrupt')
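
Wrapping the model in nn.DataParallel has a side effect that explains two other edits in this commit: the wrapper does not forward custom attributes such as n_channels or n_classes (hence the logging block commented out above), and checkpoints saved from it carry the 'module.' key prefix that app.py now strips. A short illustration with a trivial module standing in for MultiUnet, so the snippet runs on its own:

from torch import nn


class TinyNet(nn.Module):
    """Stand-in for MultiUnet that carries the same bookkeeping attributes."""

    def __init__(self, n_channels=1, n_classes=1):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.conv = nn.Conv2d(n_channels, n_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


net = TinyNet()
wrapped = nn.DataParallel(net)       # harmless on a CPU-only machine; it simply keeps the module

print(wrapped.module.n_classes)      # 1 -- custom attributes live on .module
# print(wrapped.n_classes)           # AttributeError on standard PyTorch builds
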
unet/train.py

@@ -17,6 +17,8 @@ from utils.eval import eval_net
 from torch.utils.tensorboard import SummaryWriter
 from utils.dataset import BasicDataset, VOCSegmentation
+from utils.dice_loss import DiceLoss
+from utils.focal_loss import FocalLoss
 from torch.utils.data import DataLoader, random_split

 dir_img = 'data/train_imgs/'

@@ -48,11 +50,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
     # optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay = 1e-8)
     optimizer = RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.99)
-    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps=1e-20, factor=0.5, patience=5)
-    if net.n_classes > 1:
-        criterion = nn.CrossEntropyLoss()
-    else:
-        criterion = nn.BCEWithLogitsLoss()
+    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'max', eps=1e-20, factor=0.1, patience=8)
+    # if net.n_classes > 1:
+    #     criterion = nn.CrossEntropyLoss()
+    # else:
+    criterion = DiceLoss()  #FocalLoss(alpha = .75, gamma = 2,logits = True)# nn.BCEWithLogitsLoss()

     for epoch in range(epochs):
         net.train()

@@ -60,11 +62,11 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
         with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
             for imgs, true_masks in train_loader:
                 imgs = imgs.to(device=device, dtype=torch.float32)
-                mask_type = torch.float32 if net.n_classes == 1 else torch.long
+                mask_type = torch.float32  # if net.n_classes == 1 else torch.long
                 true_masks = true_masks.to(device=device, dtype=mask_type)

                 masks_pred = net(imgs)
-                loss = criterion(masks_pred, true_masks)
+                loss = criterion(torch.sigmoid(masks_pred), true_masks)
                 epoch_loss += loss.item()
                 writer.add_scalar('Loss/train', loss.item(), global_step)

@@ -78,19 +80,19 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, save_cp = True)
             val_score = eval_net(net, val_loader, device, n_val)
             scheduler.step(val_score)
             lr = optimizer.param_groups[0]['lr']
-            if net.n_classes > 1:
-                logging.info('Validation cross entropy: {}'.format(val_score))
-                writer.add_scalar('Loss/test', val_score, global_step)
-            else:
-                logging.info('Validation Dice Coeff: {} lr:{}'.format(val_score, lr))
-                writer.add_scalar('Dice/test', val_score, global_step)
+            # if net.n_classes > 1:
+            #     logging.info('Validation cross entropy: {}'.format(val_score))
+            #     writer.add_scalar('Loss/test', val_score, global_step)
+            # else:
+            logging.info('Validation Dice Coeff: {} lr:{}'.format(val_score, lr))
+            writer.add_scalar('Dice/test', val_score, global_step)

             writer.add_images('images', imgs, global_step)
-            if net.n_classes == 1:
-                writer.add_images('masks/true', true_masks, global_step)
-                writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
+            # if net.n_classes == 1:
+            writer.add_images('masks/true', true_masks, global_step)
+            writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

-        if save_cp and epoch % 50 == 0:
+        if save_cp and (epoch + 1) % 10 == 0:
             try:
                 os.mkdir(dir_checkpoint)
                 logging.info('Created checkpoint directory')
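
The loss is now computed on torch.sigmoid(masks_pred) rather than on the raw network output because DiceLoss (defined in utils/dice_loss.py below) expects probabilities in [0, 1]; unlike the nn.BCEWithLogitsLoss it replaces, it applies no sigmoid internally. A condensed sketch of the two setups, with random tensors standing in for the UNet output and ground-truth mask:

import torch
from torch import nn

# Raw UNet output ("logits") and a binary ground-truth mask, shape N x 1 x H x W.
logits = torch.randn(2, 1, 8, 8)
target = (torch.rand(2, 1, 8, 8) > 0.5).float()

# Previous setup: BCEWithLogitsLoss applies the sigmoid itself.
bce = nn.BCEWithLogitsLoss()(logits, target)

# New setup: squash to probabilities first, then feed a Dice-style loss.
probs = torch.sigmoid(logits)
smooth = 1.0
inter = (probs * target).flatten(1).sum(1)
dice = 2 * (inter + smooth) / (probs.flatten(1).sum(1) + target.flatten(1).sum(1) + smooth)
dice_loss = 1 - dice.mean()

print(bce.item(), dice_loss.item())
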
utils/dice_loss.py

 import torch
 from torch.autograd import Function
+import torch.nn as nn


 class DiceCoeff(Function):
     """Dice coeff for individual examples"""

@@ -37,11 +38,21 @@ def dice_coeff(input, target):
     return s / (i + 1)


-def dice_coef(pred, target):
-    smooth = 1.
-    num = pred.size(0)
-    m1 = pred.view(num, -1)  # Flatten
-    m2 = target.view(num, -1)  # Flatten
-    intersection = (m1 * m2).sum()
+class DiceLoss(nn.Module):
+    def __init__(self):
+        super(DiceLoss, self).__init__()
+
+    def forward(self, input, target):
+        N = target.size(0)
+        smooth = 1
+
+        input_flat = input.view(N, -1)
+        target_flat = target.view(N, -1)
+
+        intersection = input_flat * target_flat
+
+        loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
+        loss = 1 - loss.sum() / N
+
+        return loss

-    return (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
\ No newline at end of file
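
The DiceLoss forward implements the familiar soft Dice coefficient, 2|A∩B| / (|A| + |B|), computed per sample with a smoothing term of 1 and then averaged over the batch. A tiny worked example with one flattened 8-element mask, checked against the same formula by hand (values chosen here for illustration):

import torch

pred = torch.tensor([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]])     # one sample, already in [0, 1]
target = torch.tensor([[1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]])

smooth = 1
inter = (pred * target).sum(1)                               # |A ∩ B| = 2
dice = 2 * (inter + smooth) / (pred.sum(1) + target.sum(1) + smooth)
loss = 1 - dice.sum() / pred.size(0)                         # same reduction as DiceLoss.forward

print(dice.item())   # 2 * (2 + 1) / (4 + 3 + 1) = 0.75
print(loss.item())   # 0.25; identical masks give 0, disjoint masks push it toward 1
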
utils/eval.py

@@ -4,7 +4,7 @@ from tqdm import tqdm
 from sklearn.metrics import jaccard_score
 import numpy as np

-from utils.dice_loss import dice_coeff, dice_coef
+from utils.dice_loss import dice_coeff


 def eval_net(net, loader, device, n_val):

@@ -15,17 +15,17 @@ def eval_net(net, loader, device, n_val):
     with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
         for imgs, true_masks in loader:
             imgs = imgs.to(device=device, dtype=torch.float32)
-            mask_type = torch.float32 if net.n_classes == 1 else torch.long
+            mask_type = torch.float32  # if net.n_classes == 1 else torch.long
             true_masks = true_masks.to(device=device, dtype=mask_type)

             mask_pred = net(imgs)

             for true_mask, pred in zip(true_masks, mask_pred):
                 pred = (pred > 0.5).float()
-                if net.n_classes > 1:
-                    tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
-                else:
-                    tot += dice_coef(pred, true_mask.squeeze(dim=1)).item()
+                # if net.n_classes > 1:
+                #     tot += F.cross_entropy(pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
+                # else:
+                tot += dice_coeff(pred, true_mask.squeeze(dim=1)).item()
             pbar.update(imgs.shape[0])

     return tot / n_val

@@ -37,7 +37,7 @@ def eval_jac(net, loader, device, n_val):
     with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
         for imgs, true_masks in loader:
             imgs = imgs.to(device=device, dtype=torch.float32)
-            mask_type = torch.float32 if net.n_classes == 1 else torch.long
+            mask_type = torch.float32
             true_masks = true_masks.to(device=device, dtype=mask_type)
             pred_masks = net(imgs)
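
eval_jac, imported alongside eval_net in mrnet/train.py, builds on sklearn.metrics.jaccard_score (imported at the top of this file), while eval_net reports Dice. For binary masks the two overlap measures are monotonically related, J = D / (2 - D), so they rank checkpoints identically even though the absolute numbers differ. A quick numeric check with made-up labels:

import numpy as np
from sklearn.metrics import jaccard_score

y_true = np.array([1, 1, 0, 0, 1, 0, 1, 1])
y_pred = np.array([1, 0, 0, 0, 1, 1, 1, 1])

inter = np.logical_and(y_true, y_pred).sum()
dice = 2 * inter / (y_true.sum() + y_pred.sum())     # 8/10 = 0.8
jac = jaccard_score(y_true, y_pred)                  # 4/6 ≈ 0.667

print(dice, jac, dice / (2 - dice))                  # the last two agree: J = D / (2 - D)
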
utils/focal_loss.py  (new file, mode 100644)

+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FocalLoss(nn.Module):
+    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
+        super(FocalLoss, self).__init__()
+        self.alpha = alpha
+        self.gamma = gamma
+        self.logits = logits
+        self.reduce = reduce
+
+    def forward(self, inputs, targets):
+        if self.logits:
+            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
+        else:
+            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
+        pt = torch.exp(-BCE_loss)
+        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
+
+        if self.reduce:
+            return torch.mean(F_loss)
+        else:
+            return F_loss
\ No newline at end of file
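
FocalLoss down-weights easy pixels: the per-pixel BCE is turned into pt = exp(-BCE), the probability assigned to the true class, and then scaled by alpha * (1 - pt) ** gamma. The logits flag selects whether inputs are raw scores or probabilities; mrnet/train.py constructs it with logits=False because MultiUnet now ends in a Sigmoid. Note also that the new file uses the older reduce=False keyword, which current PyTorch still accepts but flags as deprecated in favour of reduction='none'. A sketch of both input paths using the functional API directly (random tensors, illustrative only):

import torch
import torch.nn.functional as F

alpha, gamma = 1.0, 2.0
targets = (torch.rand(2, 1, 8, 8) > 0.5).float()
logits = torch.randn(2, 1, 8, 8)
probs = torch.sigmoid(logits)

# logits=True path: per-pixel BCE straight from raw scores.
bce_from_logits = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
# logits=False path (the one mrnet/train.py uses, since the net already ends in Sigmoid).
bce_from_probs = F.binary_cross_entropy(probs, targets, reduction='none')

pt = torch.exp(-bce_from_probs)                       # probability of the true class per pixel
focal = alpha * (1 - pt) ** gamma * bce_from_probs    # easy pixels (pt near 1) contribute little

print(torch.allclose(bce_from_logits, bce_from_probs, atol=1e-5))   # same quantity, two entry points
print(focal.mean().item())                                          # matches FocalLoss(reduce=True) with the same settings
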