文档服务地址:http://47.92.0.57:3000/ 周报索引地址:http://47.92.0.57:3000/s/NruNXRYmV

Commit 1cab996a by 王肇一

generate mask avoid real signal

parent 9a523da8
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import os
import cv2 as cv
import argparse
def get_args():
parser = argparse.ArgumentParser(description = 'Identify targets from background by KMeans or Threshold',
formatter_class = argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-i', '--input', metavar = 'I', type = str, default = './',
help = 'input_dir', dest = 'input')
parser.add_argument('-o','--output',metavar = 'O',type = str,default = './out/',help='output dir',dest = 'output')
return parser.parse_args()
args = get_args()
os.mkdir(args.output)
for name in os.listdir(args.input):
img = cv.imread(args.input+name,flags = cv.IMREAD_GRAYSCALE)
cv.imwrite(args.output+name,img)
\ No newline at end of file
......@@ -4,8 +4,10 @@ import numpy as np
import cv2 as cv
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import os
import re
from cvBasedMethod.filters import fft_mask,butterworth
def remove_scratch(img):
......@@ -71,32 +73,47 @@ def save_img(img_list, dir, name):
plt.title(title)
plt.imshow(img, 'gray')
try:
os.makedirs('out/' + dir + '/graph')
os.makedirs('data/output/' + dir + '/graph')
except:
logging.info('Existing dir: out/' + dir + '/graph')
plt.savefig('out/' + dir + '/graph/' + name + '.png')
logging.info('Existing dir: data/output/' + dir )
plt.savefig('data/output/' + dir + '/graph/' + name[:-4] + '.png')
plt.close()
def calcRes(img, mask, dir = 'output', name = 'output'):
def calcRes(img, mask, dir = 'default_output', name = 'output'):
dic = get_subarea_infor(img, mask)
df = pd.DataFrame(dic)
try:
os.makedirs('out/' + dir + '/csv')
os.makedirs('data/output/' + dir + '/csv')
except:
logging.info('Existing dir: out/' + dir + '/csv')
df.to_csv('out/' + dir + '/csv/' + name + '.csv')
if len(df)!=0:
df.to_csv('data/output/' + dir + '/csv/' + name + '.csv',index = False)
def draw_bar(exName,names):
df = pd.DataFrame(columns = ('class', 'perc', 'Label', 'Area', 'Mean', 'Std', 'BackMean', 'BackStd'))
for name in names:
tmp = pd.read_csv('data/output/' + exName + '/csv/' + name)
match_group = re.match('.*\s([dD]2[oO]|[lL][bB]|.*ug).*\s(.+)\.csv', name)
tmp['perc'] = str.lower(match_group.group(1))[:-2] if str.lower(match_group.group(1)).endswith('ug') else str.lower(
match_group.group(1))
tmp['perc'].replace({'d2o':'0'},inplace = True)
tmp['class'] = str.lower(match_group.group(2))
df = df.append(tmp, ignore_index = True, sort = True)
df = df[df['Area']>19]
df['Pure'] = df['Mean'] - df['BackMean']
sns.set_style("darkgrid")
#sns.catplot(x = 'perc',y = 'Mean',hue = 'class',kind='bar',data = df)
sns.pairplot(df, vars = ['Area', 'Mean', 'perc','class'])
plt.show()
def get_subarea_infor(img, mask):
area_num, labels, stats, centroids = cv.connectedComponentsWithStats(mask)
label_group = []
area_group = []
mean_group = []
std_group = []
back_mean = []
back_std = []
info = []
for i in filter(lambda x: x != 0, range(area_num)):
group = np.where(labels == i)
......@@ -113,19 +130,16 @@ def get_subarea_infor(img, mask):
res[x, y] = mask[x, y]
else:
res[x, y] = 0
kernel = np.ones((17, 17), np.uint8)
mask_background = cv.erode(255 - res, kernel)
minimask = cv.bitwise_xor(mask_background, 255 - res)
realminimask = cv.bitwise_and(minimask, 255 - mask)
img_background = img[np.where(minimask != 0)]
img_background = img[np.where(realminimask != 0)]
mean_value = np.mean(img_background)
std_value = np.std(img_background)
label_group.append(i)
area_group.append(area_tmp)
mean_group.append(mean_tmp)
std_group.append(std_tmp)
back_mean.append(mean_value)
back_std.append(std_value)
return {'Label': label_group, 'Area': area_group, 'Mean': mean_group, 'Std': std_group, 'BackMean': back_mean,
'BackStd': back_std}
info.append({'Label': i, 'Area': area_tmp, 'Mean': mean_tmp, 'Std': std_tmp, 'BackMean': mean_value,
'BackStd': std_value})
return info
\ No newline at end of file
......@@ -24,7 +24,7 @@ def method_threshold(imglist, process = 8):
def method_newThreshold(imglist, process = 8):
pool = Pool(process)
pool.map(lambda x: threshold(x,'fft'), imglist)
pool.map(lambda x: threshold(x, 'fft'), imglist)
pool.close()
pool.join()
......@@ -32,8 +32,9 @@ def method_newThreshold(imglist, process = 8):
def get_args():
parser = argparse.ArgumentParser(description = 'Identify targets from background by KMeans or Threshold',
formatter_class = argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-m', '--method', metavar = 'M', type = int, default = 0,
help = '0 for KMeans; 1 for Threshold; 2 for newThreshold', dest = 'method')
parser.add_argument('-m', '--method', metavar = 'M', type = int, default = 3,
help = '0 for KMeans; 1 for Threshold; 2 for newThreshold; 3 for Unet; 4 for further process',
dest = 'method')
parser.add_argument('-c', '--core', metavar = 'C', type = int, default = 5, help = 'Num of cluster', dest = 'core')
parser.add_argument('-p', '--process', metavar = 'P', type = int, default = 8, help = 'Num of process',
dest = 'process')
......@@ -48,16 +49,18 @@ def get_args():
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
path = ['data/imgs/'+x for x in os.listdir('data/imgs')]
path = [(y, x) for y in filter(lambda x: x != '.DS_Store', os.listdir('data/imgs')) for x in filter(
lambda x: x.endswith('.tif') and not x.endswith('dc.tif') and not x.endswith('DC.tif') and not x.endswith(
'dc .tif'), os.listdir('data/imgs/' + y))]
if args.method == 0:
method_kmeans(path,args.core)
method_kmeans(path, args.core)
elif args.method == 1:
method_threshold(path,args.process)
method_threshold(path, args.process)
elif args.method == 2:
method_newThreshold(path,args.process)
method_newThreshold(path, args.process)
elif args.method == 3:
predict(path, ['data/output/imgs/'+name[10:] for name in path], args.load, args.scale, args.mask_threshold)
predict(path, args.load, args.scale, args.mask_threshold)
for exName in filter(lambda x: x != '.DS_Store', os.listdir('data/output')):
draw_bar(exName, os.listdir('data/output/' + exName + '/csv'))
......@@ -3,10 +3,13 @@ import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import cv2 as cv
from tqdm import tqdm
from torchvision import transforms
from unet import UNet
from utils.dataset import BasicDataset
from cvBasedMethod.util import save_img,calcRes
def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5):
......@@ -30,14 +33,7 @@ def predict_img(net, full_img, device, scale_factor = 1, out_threshold = 0.5):
return full_mask > out_threshold
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))
def predict(img_name, outdir, model, scale, mask_threshold):
in_files = img_name
out_files = outdir
def predict(file_names, model, scale, mask_threshold):
net = UNet(n_channels = 1, n_classes = 1)
logging.info("Loading model {}".format(model))
......@@ -49,15 +45,12 @@ def predict(img_name, outdir, model, scale, mask_threshold):
logging.info("Model loaded !")
for i, fn in enumerate(in_files):
logging.info("\nPredicting image {} ...".format(fn))
img = Image.open(fn)
for i, fn in enumerate(tqdm(file_names)):
logging.info("\nPredicting image {} ...".format(fn[0]+'/'+fn[1]))
img = Image.open('data/imgs/' + fn[0] + '/' + fn[1])
mask = predict_img(net = net, full_img = img, scale_factor = scale, out_threshold = mask_threshold,
device = device)
#out_fn = out_files[i]
result = mask_to_image(mask)
result.save(out_files[i])
logging.info("Mask saved to {}".format(out_files[i]))
\ No newline at end of file
result = (mask * 255).astype(np.uint8) # result.save(out_files[i]) # logging.info("Mask saved to {}".format(out_files[i]))
save_img({'ori':img,'mask':result},fn[0],fn[1])
calcRes(cv.cvtColor(np.asarray(img), cv.COLOR_RGB2BGR),result,fn[0],fn[1][:-4])
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment