Hi all, I've been working on my graduation design project. The logic of the code is to split a large image of a corn field into small tiles and then run detection on each tile one by one; detected corn seedlings are annotated and counted. My advisor sent me this project directly and says it runs fine on his old machine, but on mine it keeps raising CUDA out-of-memory errors. My GPU is a 2060, which should have 6 GB of VRAM. Could anyone help me figure out the cause? I'm about to lose it.
Here is the code:
# -*- coding: utf-8 -*-
import sys
import _init_paths  # project path setup
import os
import time  # time.time() and time.sleep() are used below
import cv2
import argparse
import pdb
import numpy as np
import scipy.io as sio
from tqdm import tqdm
from tools.methods import *
from PIL import Image
from settings import Paths, scales
from detect import detect
def start():
imfolderpath = opt.img_folder_path
respath = Paths['respath']
    while True:
unHandledImgs = fileImgs(imfolderpath)
if not unHandledImgs:
print ("No Image, scan again")
else:
for Dimg in tqdm(unHandledImgs):
print ("file: %s" % Dimg)
res_combine = opt.res_combine
t1 = time.time()
candidates = []
candidates_seeds = []
candidates_point = []
img_name = Dimg.split('/')[-1].split('.')[0]
print ('image: %s' % img_name, '\n', \
'Start time: ', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), '\n')
                resfilePath = respath + "/" + Dimg.split('\\')[-1].split('.')[0]  # [-1] takes everything after the last '\', i.e. the file name without extension
smallpath = respath + "/" + Dimg.split('\\')[-1].split('.')[0] + "/small_cut"
if not os.path.exists(smallpath):
os.makedirs(smallpath)
image = cv2.imread(Dimg)
im_y,im_x = image.shape[:-1]
                residuum_y = (im_y - scales['window_size']) % scales['slip_window_step']  # leftover rows in the Y direction after tiling
                residuum_x = (im_x - scales['window_size']) % scales['slip_window_step']  # leftover columns in the X direction after tiling
                # pad up to a whole multiple of slip_window_step -- this is where the black-padded border comes from (the canvas starts as zeros)
im_black = np.zeros((im_y+(scales['slip_window_step'] - residuum_y),im_x + (scales['slip_window_step'] - residuum_x),3),dtype=np.uint8)
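                # Worked example of the padding (numbers are only an illustration; the real
                # values come from settings.scales): with window_size = 1136, slip_window_step
                # = 1136 and im_y = 4173, residuum_y = (4173 - 1136) % 1136 = 765, so the
                # canvas grows by 1136 - 765 = 371 rows to 4544 = 4 * 1136 and the sliding
                # window fits exactly.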
image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
im_blank_pil = Image.fromarray(cv2.cvtColor(im_black, cv2.COLOR_BGR2RGB))
im_blank_pil.paste(image_pil, (0,0))
img = cv2.cvtColor(np.asarray(im_blank_pil),cv2.COLOR_RGB2BGR)
cv2.imwrite(resfilePath+'/resized_im.jpg',img)
image_copy = img.copy()
line_image = img.copy()
conf_thresh = opt.conf_thresh
                # split into tiles and process each
for i in tqdm(range(0, im_y + (scales['slip_window_step'] - residuum_y), scales['slip_window_step'])):
for j in tqdm(range(0,im_x + (scales['slip_window_step'] - residuum_x), scales['slip_window_step'])):
im_cut = img[i:scales['window_size'] + i,j:scales['window_size'] + j]
if im_cut.shape[0] != scales['window_size'] or im_cut.shape[1] != scales['window_size']:
pass
else:
                            cut_path = smallpath+"/"+str(i)+"_"+str(j)+".jpg"  # tile named by its row/column pixel offset
cv2.imwrite(cut_path, im_cut)
candidates += detect(cut_path, j, i, conf_thresh)
                # candidates: [0] x1; [1] y1; [2] x2; [3] y2; [4] cls; [5] area; [6] midpts_x; [7] midpts_y; [8] leaf_length
print("this is done",candidates)
for i,k in enumerate(candidates):
if k[4] == '1':
candidates_seeds.append(k)
elif k[4] == '2':
candidates_point.append(k)
candidates_seed_before = adjRect(candidates_seeds,1)
candidates_point_before = adjRect(candidates_point,1)
# print("look2:", candidates)
candidates_point_list = get_points_list(candidates_point_before, im_x/2)
print("look:", candidates_point_list)
count_row = 1
res_all = []
count_all = 0
                # section loop
                # process each plot region in turn
                # (from here on this part differs from the original version)
for i,k in enumerate(candidates_point_list):
print("testtesttesttest")
orects = []
rects = []
rects_final = []
print("tjsoosdj ")
if i+1 == len(candidates_point_list):
print("break咯")
break
# print('len_cands:', len(candidates_point_list))
# print('\n', candidates_point_list)
                    # why is this a three-dimensional vector?
                    # (this doesn't seem to execute)
new_ordered_pts = [[candidates_point_list[i][0][0], candidates_point_list[i][0][1]],
[candidates_point_list[i][1][0], candidates_point_list[i][1][1]],
[candidates_point_list[i+1][0][0], candidates_point_list[i+1][0][1]],
[candidates_point_list[i+1][1][0], candidates_point_list[i+1][1][1]]]
print ("op: %s" % new_ordered_pts)
candidates_seed_after,k1,k2 = setSec(candidates_seed_before, new_ordered_pts)
rects = candidates_seed_after
                    # 2019/12/20: plot segmentation
minus = int(((candidates_point_list[i+1][1][1]+candidates_point_list[i+1][0][1])-
(candidates_point_list[i][1][1]+candidates_point_list[i][0][1]))/2)
print ("minus: %d" % minus)
ordered_rects,res_rects = detSec(new_ordered_pts,rects,k1,k2,count_row,minus,res_combine)
                    # collection of seedlings
rects_final+=res_rects
                    # collection of plot regions
orects+=ordered_rects
if res_combine > 1:
del_secs_num = len(orects)%(res_combine)
if del_secs_num > 0:
orects_new = orects[:-del_secs_num]
else:
orects_new = orects
orects_final = []
rects_final_final = []
idx_ori = 1
for idx in range(0,len(orects_new)-res_combine+1,res_combine):
orects_final.append([orects_new[idx][0],
orects_new[idx][1],
orects_new[idx+res_combine-1][2],
orects_new[idx+res_combine-1][3],
count_row,
idx_ori])
idx_ori += 1
for idx,value in enumerate(rects_final):
if int(value[5])%res_combine > 0:
if int(value[5])/res_combine + 1 > idx_ori - 1:
pass
else:
rects_final_final.append([value[0],
value[1],
value[2],
value[3],
value[4],
int(value[5])/res_combine + 1,
int(value[5]%res_combine)])
elif int(value[5])%res_combine == 0:
rects_final_final.append([value[0],
value[1],
value[2],
value[3],
value[4],
int(value[5])/res_combine,
int(value[5]%res_combine)])
rects_final = rects_final_final
else:
orects_final = orects
# print ("len orects: %d/ len orects_final: %d" % len(orects),len(orects_final))
for e in range(0, len(orects_final)):
line_image = cv2.rectangle(line_image,
(orects_final[e][0],orects_final[e][1]),
(orects_final[e][2],orects_final[e][3]),
(255, 0, 0),
2)
for e in range(0, len(rects_final)):
line_image = cv2.rectangle(line_image,
(int(rects_final[e][0]-30),int(rects_final[e][1]-30)),
(int(rects_final[e][0]+30),int(rects_final[e][1]+30)),
(0, 0, 255),
2)
for e in range(0,len(orects_final)):
line_image = cv2.putText(line_image,
str(orects_final[e][4])+","+str(orects_final[e][5]),
(int((orects_final[e][0]+orects_final[e][2])/2-75),int((orects_final[e][1]+orects_final[e][3])/2)),
cv2.FONT_HERSHEY_SIMPLEX,
2,
(255, 255, 0),
2)
count_row += 1
                    # per-plot data analysis
for idx in range(0,len(orects_final)):
pts = []
for pts_idx in range(0,len(rects_final)):
if rects_final[pts_idx][4] == orects_final[idx][4]:
if rects_final[pts_idx][5] == orects_final[idx][5]:
# print rects_final[pts_idx],"\n",orects_final[idx]
if rects_final[pts_idx] in pts:
continue
else:
pts.append(rects_final[pts_idx])
if pts == []:
continue
res_all.append(GetRes(res_combine,orects_final[idx],pts,minus,orects_final[idx][4],orects_final[idx][5]))
count_all += 1
# pdb.set_trace()
# print (res_all)
gen_csv(count_all, res_all, resfilePath, Dimg.split('\\')[-1].split('.')[0])
image = cv2.imread(Dimg)
cv2.imwrite(resfilePath + "/ori.jpg", image)
cv2.imwrite(resfilePath + "/res.jpg", line_image)
# cv2.imwrite(resfilePath + "/" + file.split('\\')[-1].split('.')[0] + ".jpg", image_copy)
t2 = time.time()
print ("end time: ", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), "\n", \
"Using time: ", str(t2 - t1), "s", "\n")
#path = os.path.join(resfilePath, file.split('/')[-1].split('.')[0] + '.csv')
# for j in range(0,len(unHandledImgs)):
# try:
# os.remove(unHandledImgs[j])
# except:
# pdb.set_trace()
# print("1")
time.sleep(30)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--conf-thresh', type=float, default=0.5)
parser.add_argument('--res-combine', type=int, default=1)
parser.add_argument('--img-folder-path', type=str, default='result/unHandled')
opt = parser.parse_args()
print(opt)
start()
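To narrow it down, my plan is to log GPU memory after each tile with a small helper like this (my own snippet, not part of the project; torch.cuda.memory_allocated / max_memory_allocated are standard PyTorch calls):

import torch

def log_vram(tag):
    # Print current and peak GPU memory so I can see at which tile usage climbs.
    alloc = torch.cuda.memory_allocated() / 1024 ** 2
    peak = torch.cuda.max_memory_allocated() / 1024 ** 2
    print("%s: %.1f MiB allocated, %.1f MiB peak" % (tag, alloc, peak))

# e.g. right after the `candidates += detect(...)` line:
# log_vram("tile %d_%d" % (i, j))

If the numbers climb steadily tile after tile instead of staying flat, something is being kept alive across calls.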
As I understand it, the script above calls the detect function to process the tiles one by one. Here is the detect function:
import argparse
import time    # time.time() is used below
import random  # used for the per-class colors
import cv2
import pdb
import torch
from sys import platform
from models import * # set ONNX_EXPORT in models.py
from utils.datasets import *
from utils.utils import *
from PIL import Image
from settings import *
# modified 2020/04/23
def detect(img_fn, add_x, add_y, conf=0.5):
res=[]
# if img_type == 'seeds':
img_size = Paras_common['img_size']
weights = Paras_common['weights']
conf_thres = conf
iou_thres = Paras_common['iou_thres']
name = Paras_common['names']
# elif img_type == 'point':
# img_size = Paras_point['img_size']
# weights = Paras_point['weights']
# conf_thres = Paras_point['conf_thres']
# iou_thres = Paras_point['iou_thres']
# name = Paras_point['names']
# Initialize
# device = torch_utils.select_device(device='cpu' if ONNX_EXPORT else opt.device)
device = torch_utils.select_device('0')
    # Initialize model -- note this happens inside detect(), i.e. once per tile
    model = Darknet(Paras_common['cfg'], img_size)
    # Load weights onto the device, again on every call
    model.load_state_dict(torch.load(weights, map_location=device)['model'])
# Eval mode
model.to(device).eval()
# Set Dataloader
vid_path, vid_writer = None, None
    # load the tile image from the path where the small cuts are stored
dataset = LoadImages(img_fn, img_size=img_size)
# Get names and colors
names = load_classes(name)
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
# Run inference
t0 = time.time()
for path, img, im0s, vid_cap in dataset:
#count_aaa = int(path[-5])
img = torch.from_numpy(img).to(device)
img = img.float() # uint8 to fp16/32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
if img.ndimension() == 3:
img = img.unsqueeze(0)
        # Inference (note: when detect() is called from start(), nothing wraps this in torch.no_grad())
t1 = torch_utils.time_synchronized()
# print(img.shape)
pred = model(img)[0]
t2 = torch_utils.time_synchronized()
# Apply NMS
pred = non_max_suppression(pred, conf_thres=conf_thres, iou_thres=iou_thres, classes=Paras_common['classes'], agnostic=Paras_common['agnostic_nms'])
# Process detections
for i, det in enumerate(pred): # detections per image
s, im0 = '', im0s
#save_path = str(Path(out) / Path(p).name)
s += '%gx%g ' % img.shape[2:] # print string
if det is not None and len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
# Print results
for c in det[:, -1].unique():
n = (det[:, -1] == c).sum() # detections per class
s += '%g %ss, ' % (n, names[int(c)]) # add to string
# Write results
for *xyxy, conf, cls in det:
label = '%s %.2f' % (names[int(cls)], conf)
x1 = int(xyxy[0]+add_x)
x2 = int(xyxy[2]+add_x)
y1 = int(xyxy[1]+add_y)
y2 = int(xyxy[3]+add_y)
# y1 = int(xyxy[1]+int(path[-5]-1)*1136)
# y2 = int(xyxy[3]+int(path[-5]-1)*1136)
midPts_x = int((x1+x2)/2)
midPts_y = int((y1+y2)/2)
area = (x2-x1)*(y2-y1)
ear_len = max((x2-x1),(y2-y1))
if area < 800:
continue
else:
if min((x2-x1),(y2-y1)) < 20:
continue
else:
if label[0] == '1':
res.append((x1, y1, x2, y2, label[0], area, midPts_x, midPts_y, ear_len))
elif label[0] == '2':
if 0.8<abs((x2-x1)/(y2-y1))<1.2:
res.append((x1, y1, x2, y2, label[0], area, midPts_x, midPts_y, ear_len))
else:
pass
# plot_one_box(xyxy, im0, label=label,draw_point=draw_point, color=colors[int(cls)])
# Print time (inference + NMS)
# print('%sDone. (%.3fs)' % (s, t2 - t1))
# print('Done. (%.3fs)' % (time.time() - t0))
return res
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='*.cfg path')
parser.add_argument('--names', type=str, default='data/mqms.names', help='*.names path')
parser.add_argument('--weights', type=str, default='weights/best.pt', help='weights path')
parser.add_argument('--source', type=str, default='data/samples', help='source') # input file/folder, 0 for webcam
parser.add_argument('--output', type=str, default='output', help='output folder') # output folder
parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
parser.add_argument('--conf-thres', type=float, default=0.3, help='object confidence threshold')
parser.add_argument('--iou-thres', type=float, default=0.2, help='IOU threshold for NMS')
parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec (verify ffmpeg support)')
parser.add_argument('--half', action='store_true', help='half precision FP16 inference')
parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
parser.add_argument('--view-img', action='store_true', help='display results')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--draw-point', action='store_true', help='draw mid point instead of box')
opt = parser.parse_args()
# print(opt)
with torch.no_grad():
        detect(opt.source, 0, 0, opt.conf_thres)  # detect() requires (img_fn, add_x, add_y); a bare detect() here would raise a TypeError
From setting breakpoints, it looks like the out-of-memory happens during the call to detect. The image is 5569×4173 pixels, about 10.2 MB.
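What I suspect (please correct me if I'm wrong): detect() builds a fresh Darknet model and loads the weights onto the GPU for every single tile, and when it is called from start() the forward pass never runs under torch.no_grad() (that wrapper only exists in detect's own __main__ block), so autograd keeps the activation graph alive. Would restructuring it roughly like this be the fix? Just a sketch reusing names from the code above (Darknet, Paras_common, LoadImages, non_max_suppression all come from the project); detect_tile is my own hypothetical name, untested:

import torch

# Build and load the model ONCE, before the tile loop in start().
device = torch_utils.select_device('0')
model = Darknet(Paras_common['cfg'], Paras_common['img_size'])
model.load_state_dict(torch.load(Paras_common['weights'], map_location=device)['model'])
model.to(device).eval()

def detect_tile(model, device, img_fn, add_x, add_y, conf=0.5):
    # Same post-processing as detect(), but without the per-call model setup,
    # and the forward pass runs under no_grad so activations are freed per tile.
    res = []
    dataset = LoadImages(img_fn, img_size=Paras_common['img_size'])
    with torch.no_grad():
        for path, img, im0s, vid_cap in dataset:
            img = torch.from_numpy(img).to(device).float() / 255.0  # uint8 -> [0, 1]
            if img.ndimension() == 3:
                img = img.unsqueeze(0)
            pred = model(img)[0]
            pred = non_max_suppression(pred, conf_thres=conf,
                                       iou_thres=Paras_common['iou_thres'],
                                       classes=Paras_common['classes'],
                                       agnostic=Paras_common['agnostic_nms'])
            # ...then the same box filtering/offsetting as in detect() above...
    return res

Does that sound right, or is something else eating the 6 GB?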