运行环境:VS Code,代码如下:
import cv2
import math
import numpy as np
import os
def tile_cut(in_img_path, out_img_path, valid_size, patch_size, is_GT):
    """Cut a large image into overlapping tiles with a sliding window.

    The image is covered by a grid of ``valid_size`` x ``valid_size`` cells
    (the last row/column of cells may extend past the image edge). Around the
    centre of every cell a square window of ``patch_size // 2`` pixels on each
    side is cropped. Before cropping, the image is padded with
    ``cv2.BORDER_REFLECT_101`` so that every window lies inside the padded
    image.

    Each tile is written to
    ``<out_img_path><image name>_<center_x>_<center_y>.png``, where the centre
    coordinates are expressed in the padded image's pixel coordinates.

    Args:
        in_img_path (str): Path of the large input image.
        out_img_path (str): Prefix for the tile files. The tile file name is
            appended to it verbatim (no path separator is inserted), so a
            directory must be given with its trailing separator.
        valid_size (int): Side length of one grid cell, i.e. the stride
            between neighbouring tile centres.
        patch_size (int): Side length of each cropped tile (for an odd value
            the tile is ``patch_size - 1`` pixels wide).
        is_GT (bool): If True the image is read with
            ``cv2.IMREAD_GRAYSCALE``; otherwise with ``cv2.imread``'s default
            mode.

    Returns:
        None. If the image cannot be read, its path is printed and the
        function returns without writing anything.
    """
    # splitext copes with extensions of any length (".png", ".jpeg", ...).
    img_name = os.path.splitext(os.path.basename(in_img_path))[0]

    if is_GT:
        in_img = cv2.imread(in_img_path, cv2.IMREAD_GRAYSCALE)
    else:
        in_img = cv2.imread(in_img_path)
    if in_img is None:
        # cv2.imread signals failure by returning None; stop here instead of
        # failing on in_img.shape below.
        print(in_img_path)
        return None

    img_rows, img_cols = in_img.shape[:2]  # shape is (rows, cols[, channels])

    # Number of grid cells per axis: ceiling division so that a partial cell
    # at the right/bottom edge is still covered.
    num_cols = (img_cols + valid_size - 1) // valid_size
    num_rows = (img_rows + valid_size - 1) // valid_size

    half_valid = valid_size // 2
    half_patch = patch_size // 2

    # Centres of the last grid cell, in original-image coordinates.
    center_maxx = half_valid + valid_size * (num_cols - 1)
    center_maxy = half_valid + valid_size * (num_rows - 1)

    # Padding needed on each side so that the windows around the first and
    # last cell centres stay inside the padded image.
    left_border = max(0, half_patch - half_valid)
    top_border = max(0, half_patch - half_valid)
    right_border = max(0, center_maxx + half_patch - img_cols)
    bottom_border = max(0, center_maxy + half_patch - img_rows)

    in_img_pad = cv2.copyMakeBorder(
        in_img, top_border, bottom_border, left_border, right_border,
        cv2.BORDER_REFLECT_101)

    for i in range(num_cols):
        # Cell centre in padded-image coordinates.
        center_x = half_valid + left_border + valid_size * i
        left_top_x = center_x - half_patch
        right_bottom_x = center_x + half_patch
        for j in range(num_rows):
            center_y = half_valid + top_border + valid_size * j
            left_top_y = center_y - half_patch
            right_bottom_y = center_y + half_patch

            # The borders computed above guarantee the window is in bounds,
            # so no per-tile range check is required.
            cropped = in_img_pad[left_top_y:right_bottom_y,
                                 left_top_x:right_bottom_x]

            # Both centre coordinates go into the name; using only x would
            # make every tile of a column overwrite the previous one.
            tile_path = '{}{}_{}_{}.png'.format(
                out_img_path, img_name, center_x, center_y)
            if not cv2.imwrite(tile_path, cropped):
                print('failed to write ' + tile_path)
    return None
if __name__ == '__main__':
    # Script entry point: tile every file in in_img_dir whose name ends in
    # "png". Both path strings are placeholders to be filled in before
    # running; each is used as a plain string prefix, so a directory must
    # include its trailing separator.
    in_img_dir = ' '
    out_img_path = ' '

    # Tiling parameters handed unchanged to tile_cut for every image.
    valid_size = 256
    patch_size = 512
    is_GT = False

    for file_name in os.listdir(in_img_dir):
        # Skip anything whose name does not end in "png".
        if not file_name.endswith('png'):
            continue
        print(file_name)
        tile_cut(in_img_dir + file_name, out_img_path, valid_size, patch_size, is_GT)