Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
I
Im
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
文档服务地址:
http://47.92.0.57:3000/
周报索引地址:
http://47.92.0.57:3000/s/NruNXRYmV
Open sidebar
王肇一
Im
Commits
8c28bb60
Commit
8c28bb60
authored
Feb 05, 2020
by
王肇一
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
mrnet now ready
parent
9b02e4db
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
20 additions
and
19 deletions
+20
-19
.gitignore
.gitignore
+2
-0
mrnet_module.py
mrnet/mrnet_module.py
+2
-2
mrnet_parts.py
mrnet/mrnet_parts.py
+6
-5
train.py
mrnet/train.py
+2
-2
train.py
train.py
+2
-3
dataset.py
utils/dataset.py
+6
-7
No files found.
.gitignore
View file @
8c28bb60
...
@@ -5,4 +5,6 @@ data/masks/*
...
@@ -5,4 +5,6 @@ data/masks/*
data/output/*
data/output/*
data/train_imgs/*
data/train_imgs/*
data/train_masks/*
data/train_masks/*
data/train_imgs_32/*
.ipynb_checkpoints/
.ipynb_checkpoints/
runs
mrnet/mrnet_module.py
View file @
8c28bb60
...
@@ -14,7 +14,7 @@ class MultiUnet(nn.Module):
...
@@ -14,7 +14,7 @@ class MultiUnet(nn.Module):
self
.
bilinear
=
bilinear
self
.
bilinear
=
bilinear
self
.
inconv
=
nn
.
Sequential
(
self
.
inconv
=
nn
.
Sequential
(
nn
.
Conv2d
(
n_channels
,
8
,
kernel_size
=
3
,
stride
=
1
,
padding
_mode
=
'same'
),
nn
.
Conv2d
(
n_channels
,
8
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
BatchNorm2d
(
8
),
nn
.
BatchNorm2d
(
8
),
nn
.
ReLU
(
inplace
=
True
)
nn
.
ReLU
(
inplace
=
True
)
)
)
...
@@ -41,7 +41,7 @@ class MultiUnet(nn.Module):
...
@@ -41,7 +41,7 @@ class MultiUnet(nn.Module):
self
.
pool
=
nn
.
MaxPool2d
(
2
)
self
.
pool
=
nn
.
MaxPool2d
(
2
)
self
.
outconv
=
nn
.
Conv2d
(
self
.
res9
.
outc
,
n_classes
,
kernel_size
=
3
,
padding
_mode
=
'same'
)
self
.
outconv
=
nn
.
Conv2d
(
self
.
res9
.
outc
,
n_classes
,
kernel_size
=
3
,
padding
=
1
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
inconv
(
x
)
x
=
self
.
inconv
(
x
)
...
...
mrnet/mrnet_parts.py
View file @
8c28bb60
...
@@ -16,8 +16,8 @@ class MultiResBlock(nn.Module):
...
@@ -16,8 +16,8 @@ class MultiResBlock(nn.Module):
def
conv
(
self
,
in_channel
,
out_channel
,
kernel_size
):
def
conv
(
self
,
in_channel
,
out_channel
,
kernel_size
):
return
nn
.
Sequential
(
return
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channel
,
out_channels
=
out_channel
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding
=
1
,
nn
.
Conv2d
(
in_channels
=
in_channel
,
out_channels
=
out_channel
,
kernel_size
=
kernel_size
,
stride
=
1
,
padding
=
1
)
,
padding_mode
=
'same'
),
nn
.
BatchNorm2d
(
out_channel
),
nn
.
ReLU
(
inplace
=
True
))
nn
.
BatchNorm2d
(
out_channel
),
nn
.
ReLU
(
inplace
=
True
))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
shortcut
=
nn
.
Sequential
(
shortcut
=
nn
.
Sequential
(
...
@@ -41,9 +41,9 @@ class TransCompose(nn.Module):
...
@@ -41,9 +41,9 @@ class TransCompose(nn.Module):
self
.
inc
=
in_channels
self
.
inc
=
in_channels
self
.
outc
=
out_channels
self
.
outc
=
out_channels
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
nn
.
Sequential
(
return
nn
.
Sequential
(
nn
.
ConvTranspose2d
(
in_channels
=
self
.
inc
,
out_channels
=
self
.
outc
,
kernel_size
=
3
,
stride
=
2
),
nn
.
ConvTranspose2d
(
in_channels
=
self
.
inc
,
out_channels
=
self
.
outc
,
kernel_size
=
2
,
stride
=
2
),
nn
.
BatchNorm2d
(
self
.
outc
)
nn
.
BatchNorm2d
(
self
.
outc
)
)(
x
)
)(
x
)
...
@@ -61,7 +61,8 @@ class ResPath(nn.Module):
...
@@ -61,7 +61,8 @@ class ResPath(nn.Module):
nn
.
BatchNorm2d
(
self
.
outc
))(
x
)
nn
.
BatchNorm2d
(
self
.
outc
))(
x
)
conv
=
nn
.
Sequential
(
conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
Conv2d
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
),
nn
.
BatchNorm2d
(
out_channels
),
nn
.
ReLU
(
inplace
=
True
))(
x
)
nn
.
BatchNorm2d
(
out_channels
),
nn
.
ReLU
(
inplace
=
True
))(
x
)
result
=
torch
.
add
(
conv
,
shortcut
)
result
=
torch
.
add
(
conv
,
shortcut
)
return
nn
.
Sequential
(
nn
.
ReLU
(
inplace
=
True
),
nn
.
BatchNorm2d
(
out_channels
))(
result
)
return
nn
.
Sequential
(
nn
.
ReLU
(
inplace
=
True
),
nn
.
BatchNorm2d
(
out_channels
))(
result
)
...
...
mrnet/train.py
View file @
8c28bb60
...
@@ -18,7 +18,7 @@ dir_mask = 'data/train_masks/'
...
@@ -18,7 +18,7 @@ dir_mask = 'data/train_masks/'
dir_checkpoint
=
'checkpoint'
dir_checkpoint
=
'checkpoint'
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
,
save_cp
=
True
,
img_scale
=
0.5
):
def
train_net
(
net
,
device
,
epochs
=
5
,
batch_size
=
1
,
lr
=
0.1
,
val_percent
=
0.1
,
img_scale
=
0.5
):
dataset
=
BasicDataset
(
dir_img
,
dir_mask
,
img_scale
)
dataset
=
BasicDataset
(
dir_img
,
dir_mask
,
img_scale
)
n_val
=
int
(
len
(
dataset
)
*
val_percent
)
n_val
=
int
(
len
(
dataset
)
*
val_percent
)
n_train
=
len
(
dataset
)
-
n_val
n_train
=
len
(
dataset
)
-
n_val
...
@@ -27,7 +27,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
...
@@ -27,7 +27,7 @@ def train_net(net, device, epochs = 5, batch_size = 1, lr = 0.1, val_percent = 0
val_loader
=
DataLoader
(
val
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
8
,
pin_memory
=
True
)
val_loader
=
DataLoader
(
val
,
batch_size
=
batch_size
,
shuffle
=
False
,
num_workers
=
8
,
pin_memory
=
True
)
optimizer
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
lr
)
optimizer
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
lr
)
criterion
=
nn
.
BCELoss
()
criterion
=
nn
.
BCE
WithLogits
Loss
()
for
epoch
in
tqdm
(
range
(
epochs
)):
for
epoch
in
tqdm
(
range
(
epochs
)):
net
.
train
()
net
.
train
()
...
...
train.py
View file @
8c28bb60
...
@@ -40,14 +40,13 @@ if __name__ == '__main__':
...
@@ -40,14 +40,13 @@ if __name__ == '__main__':
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
logging
.
info
(
f
'Using device {device}'
)
logging
.
info
(
f
'Using device {device}'
)
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
# n_classes is the number of probabilities you want to get per pixel
# - For 1 class and background, use n_classes=1
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
# - For N > 2 classes, use n_classes=N
#net = UNet(n_channels = 1, n_classes = 1)
#
net = UNet(n_channels = 1, n_classes = 1)
net
=
MultiUnet
(
n_channels
=
1
,
n_classes
=
1
)
net
=
MultiUnet
(
n_channels
=
1
,
n_classes
=
1
)
logging
.
info
(
f
'Network:
\n
'
logging
.
info
(
f
'Network:
\n
'
...
@@ -64,7 +63,7 @@ if __name__ == '__main__':
...
@@ -64,7 +63,7 @@ if __name__ == '__main__':
# cudnn.benchmark = True
# cudnn.benchmark = True
try
:
try
:
mrnet
.
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
mrnet
.
train_net
(
net
=
net
,
device
=
device
,
epochs
=
args
.
epochs
,
batch_size
=
args
.
batchsize
,
lr
=
args
.
lr
,
img_scale
=
args
.
scale
,
val_percent
=
args
.
val
/
100
)
img_scale
=
args
.
scale
,
val_percent
=
args
.
val
/
100
)
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
torch
.
save
(
net
.
state_dict
(),
'INTERRUPTED.pth'
)
torch
.
save
(
net
.
state_dict
(),
'INTERRUPTED.pth'
)
...
...
utils/dataset.py
View file @
8c28bb60
...
@@ -23,11 +23,10 @@ class BasicDataset(Dataset):
...
@@ -23,11 +23,10 @@ class BasicDataset(Dataset):
return
len
(
self
.
ids
)
return
len
(
self
.
ids
)
@classmethod
@classmethod
def
preprocess
(
cls
,
pil_img
,
scale
):
def
preprocess
(
cls
,
pil_img
):
w
,
h
=
pil_img
.
size
#newW, newH = 256,256 #int(scale * w), int(scale * h)
newW
,
newH
=
256
,
256
#int(scale * w), int(scale * h)
#assert newW > 0 and newH > 0, 'Scale is too small'
assert
newW
>
0
and
newH
>
0
,
'Scale is too small'
pil_img
=
pil_img
.
resize
((
256
,
256
))
pil_img
=
pil_img
.
resize
((
newW
,
newH
))
img_nd
=
np
.
array
(
pil_img
)
img_nd
=
np
.
array
(
pil_img
)
...
@@ -56,7 +55,7 @@ class BasicDataset(Dataset):
...
@@ -56,7 +55,7 @@ class BasicDataset(Dataset):
assert
img
.
size
==
mask
.
size
,
\
assert
img
.
size
==
mask
.
size
,
\
f
'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
f
'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
img
=
self
.
preprocess
(
img
,
self
.
scale
)
img
=
self
.
preprocess
(
img
)
mask
=
self
.
preprocess
(
mask
,
self
.
scale
)
mask
=
self
.
preprocess
(
mask
)
return
{
'image'
:
torch
.
from_numpy
(
img
),
'mask'
:
torch
.
from_numpy
(
mask
)}
return
{
'image'
:
torch
.
from_numpy
(
img
),
'mask'
:
torch
.
from_numpy
(
mask
)}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment