Sharkkkyu 2024-04-24 13:48 Acceptance rate: 0%

Help with a YOLO algorithm problem

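Hi all, I've been working on a design project. The code splits a large cornfield image into small sub-regions and then runs detection on them one by one. Once corn seedlings are recognized, it annotates them and collects statistics. My advisor handed me this project directly and says it runs fine on his old computer, but on my machine it keeps reporting CUDA out of memory. I have a 2060, which should have 6 GB of memory. Could someone help me find the cause? I'm at my wits' end.
Here is the code: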

# -*- coding: utf-8 -*-
import sys
import time  # used below for timing and sleep
import _init_paths
import numpy as np
import scipy.io as sio
import os, cv2
import argparse
import pdb
from tqdm import tqdm
from tools.methods import *
from PIL import Image
from settings import Paths,scales
from detect import detect
def start():
    imfolderpath = opt.img_folder_path
    respath = Paths['respath']

    while 1:
        unHandledImgs = fileImgs(imfolderpath)
        if not unHandledImgs:
            print ("No Image, scan again")
        else:
            for Dimg in tqdm(unHandledImgs):
                print ("file: %s" % Dimg)
                res_combine = opt.res_combine
                t1 = time.time()
                
                candidates = []
                candidates_seeds = []
                candidates_point = []

                img_name = Dimg.split('/')[-1].split('.')[0]
                print ('image: %s' % img_name, '\n', \
                    'Start time: ', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), '\n')

                resfilePath = respath + "/" + Dimg.split('\\')[-1].split('.')[0]  # [-1] takes the part after the last backslash, i.e. the file name
                smallpath = respath + "/" + Dimg.split('\\')[-1].split('.')[0] + "/small_cut"
                
                if not os.path.exists(smallpath):
                    os.makedirs(smallpath)

                image = cv2.imread(Dimg)
                im_y,im_x = image.shape[:-1]

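                residuum_y = (im_y - scales['window_size']) % scales['slip_window_step']  # tiling remainder along Y
                residuum_x = (im_x - scales['window_size']) % scales['slip_window_step']  # tiling remainder along X
                # pad up to a whole number of slip_window_step; this is where the black-bordered sub-region images come from (initialized to 0)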
                im_black = np.zeros((im_y+(scales['slip_window_step'] - residuum_y),im_x + (scales['slip_window_step'] - residuum_x),3),dtype=np.uint8)
                image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                im_blank_pil = Image.fromarray(cv2.cvtColor(im_black, cv2.COLOR_BGR2RGB))
                im_blank_pil.paste(image_pil, (0,0))
                img = cv2.cvtColor(np.asarray(im_blank_pil),cv2.COLOR_RGB2BGR)
                cv2.imwrite(resfilePath+'/resized_im.jpg',img)
                
                image_copy = img.copy()
                line_image = img.copy()
                conf_thresh = opt.conf_thresh
                # split into sub-regions
                for i in tqdm(range(0, im_y + (scales['slip_window_step'] - residuum_y), scales['slip_window_step'])):
                    for j in tqdm(range(0,im_x + (scales['slip_window_step'] - residuum_x), scales['slip_window_step'])):
                        im_cut = img[i:scales['window_size'] + i,j:scales['window_size'] + j]
                        if im_cut.shape[0] != scales['window_size'] or im_cut.shape[1] != scales['window_size']:
                            pass
                        else:
                            cut_path = smallpath+"/"+str(i)+"_"+str(j)+".jpg"  # save the sub-region image, named by row and column
                            cv2.imwrite(cut_path, im_cut)
                            candidates += detect(cut_path, j, i, conf_thresh)

                #candidates:[0]:x1;[1]y1;[2]x2;[3]y2;[4]cls;[5]area;[6]midpts_x;[7]midpts_y;[8]Leaf_length
                print("this is done",candidates)
                for i,k in enumerate(candidates):
                    if k[4] == '1':
                        candidates_seeds.append(k)
                    elif k[4] == '2':
                        candidates_point.append(k)
                candidates_seed_before = adjRect(candidates_seeds,1)
                candidates_point_before = adjRect(candidates_point,1)

                # print("look2:", candidates)
                candidates_point_list = get_points_list(candidates_point_before, im_x/2)
                print("look:", candidates_point_list)

                count_row = 1

                res_all = [] 
                count_all = 0                                      
                
                # section loop
                # process region by region
                # this part is where things start to differ
                for i,k in enumerate(candidates_point_list):
                    print("testtesttesttest")
                    orects = []
                    rects = []
                    rects_final = []
                    print("tjsoosdj ")
                    if i+1 == len(candidates_point_list):
                        print("breaking out")
                        break
                    # print('len_cands:', len(candidates_point_list))
                    # print('\n', candidates_point_list)
# why is there a 3-D vector here
# this doesn't seem to be executed
                    new_ordered_pts = [[candidates_point_list[i][0][0], candidates_point_list[i][0][1]],
                                        [candidates_point_list[i][1][0], candidates_point_list[i][1][1]],
                                        [candidates_point_list[i+1][0][0], candidates_point_list[i+1][0][1]],
                                        [candidates_point_list[i+1][1][0], candidates_point_list[i+1][1][1]]]
                    print ("op: %s" % new_ordered_pts)

                    candidates_seed_after,k1,k2 = setSec(candidates_seed_before, new_ordered_pts)
                    rects = candidates_seed_after

                    # 2019/12/20: plot segmentation
                    minus = int(((candidates_point_list[i+1][1][1]+candidates_point_list[i+1][0][1])-
                                    (candidates_point_list[i][1][1]+candidates_point_list[i][0][1]))/2)
                    print ("minus: %d" % minus)
                    ordered_rects,res_rects = detSec(new_ordered_pts,rects,k1,k2,count_row,minus,res_combine)

                    # seedling set
                    rects_final+=res_rects
                    # plot region set
                    orects+=ordered_rects

                    
                    if res_combine > 1:
                        del_secs_num = len(orects)%(res_combine)
                        if del_secs_num > 0:
                            orects_new = orects[:-del_secs_num]
                        else:
                            orects_new = orects 
                        orects_final = []
                        rects_final_final = []
                        idx_ori = 1
                        for idx in range(0,len(orects_new)-res_combine+1,res_combine):
                            orects_final.append([orects_new[idx][0],
                                                    orects_new[idx][1],
                                                    orects_new[idx+res_combine-1][2],
                                                    orects_new[idx+res_combine-1][3],
                                                    count_row,
                                                    idx_ori])
                            idx_ori += 1
                        for idx,value in enumerate(rects_final):
                            if int(value[5])%res_combine > 0:
                                if int(value[5])/res_combine + 1 > idx_ori - 1:
                                    pass
                                else:
                                    rects_final_final.append([value[0],
                                                                value[1],
                                                                value[2],
                                                                value[3],
                                                                value[4],
                                                                int(value[5])/res_combine + 1,
                                                                int(value[5]%res_combine)])
                            elif int(value[5])%res_combine == 0:
                                rects_final_final.append([value[0],
                                                            value[1],
                                                            value[2],
                                                            value[3],
                                                            value[4],
                                                            int(value[5])/res_combine,
                                                            int(value[5]%res_combine)])
                        rects_final = rects_final_final
                    else:
                        orects_final = orects

                    # print ("len orects: %d/ len orects_final: %d" % len(orects),len(orects_final))

                    for e in range(0, len(orects_final)):
                        line_image = cv2.rectangle(line_image,
                                                (orects_final[e][0],orects_final[e][1]),
                                                (orects_final[e][2],orects_final[e][3]),
                                                (255, 0, 0),
                                                2)
                            
                    
                    for e in range(0, len(rects_final)):
                        line_image = cv2.rectangle(line_image,
                                                   (int(rects_final[e][0]-30),int(rects_final[e][1]-30)),
                                                   (int(rects_final[e][0]+30),int(rects_final[e][1]+30)),
                                                   (0, 0, 255),
                                                   2)

                    for e in range(0,len(orects_final)):
                        line_image = cv2.putText(line_image,
                                                 str(orects_final[e][4])+","+str(orects_final[e][5]),
                                                 (int((orects_final[e][0]+orects_final[e][2])/2-75),int((orects_final[e][1]+orects_final[e][3])/2)),
                                                 cv2.FONT_HERSHEY_SIMPLEX,
                                                 2,
                                                 (255, 255, 0),
                                                 2)
                    
                    count_row += 1 
                    # per-plot data analysis
                
                    for idx in range(0,len(orects_final)):
                        pts = []
                        for pts_idx in range(0,len(rects_final)):
                            if rects_final[pts_idx][4] == orects_final[idx][4]:
                                if rects_final[pts_idx][5] == orects_final[idx][5]:
                                    # print rects_final[pts_idx],"\n",orects_final[idx]
                                    if rects_final[pts_idx] in pts:
                                        continue
                                    else:
                                        pts.append(rects_final[pts_idx])
                        if pts == []:
                            continue
                        res_all.append(GetRes(res_combine,orects_final[idx],pts,minus,orects_final[idx][4],orects_final[idx][5]))
                        count_all += 1
                    # pdb.set_trace()
                # print (res_all)
                gen_csv(count_all, res_all, resfilePath, Dimg.split('\\')[-1].split('.')[0])
                image = cv2.imread(Dimg)
                cv2.imwrite(resfilePath + "/ori.jpg", image)
                cv2.imwrite(resfilePath + "/res.jpg", line_image)
                # cv2.imwrite(resfilePath + "/" + file.split('\\')[-1].split('.')[0] + ".jpg", image_copy)
                t2 = time.time()

                print ("end time: ", time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), "\n", \
                      "Using time: ", str(t2 - t1), "s", "\n")

                #path = os.path.join(resfilePath, file.split('/')[-1].split('.')[0] + '.csv')


            # for j in range(0,len(unHandledImgs)):
            #     try:
            #         os.remove(unHandledImgs[j])
            #     except:
            #         pdb.set_trace()
                    # print("1")
        time.sleep(30)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--conf-thresh', type=float, default=0.5)
    parser.add_argument('--res-combine', type=int, default=1)  
    parser.add_argument('--img-folder-path', type=str, default='result/unHandled')
    opt = parser.parse_args()
    print(opt)
    start()

As I understand it, the detect function is called to process the sub-regions one by one. Here is the detect function:

import argparse
import cv2
import pdb
from sys import platform
from models import *  # set ONNX_EXPORT in models.py
from utils.datasets import *
from utils.utils import *
from PIL import Image
from settings import *

# modified 2020/04/23
def detect(img_fn,add_x,add_y,conf=0.5):
    res=[]
    # if img_type == 'seeds':
    img_size = Paras_common['img_size']
    weights = Paras_common['weights']
    conf_thres = conf
    iou_thres = Paras_common['iou_thres']
    name = Paras_common['names']

    # elif img_type == 'point':
    #     img_size = Paras_point['img_size']
    #     weights = Paras_point['weights']
    #     conf_thres = Paras_point['conf_thres']
    #     iou_thres = Paras_point['iou_thres']
    #     name = Paras_point['names']

    # Initialize
    # device = torch_utils.select_device(device='cpu' if ONNX_EXPORT else opt.device)
    device = torch_utils.select_device('0')

    # Initialize model
    model = Darknet(Paras_common['cfg'], img_size)

    # Load weights
    model.load_state_dict(torch.load(weights, map_location=device)['model'])

    # Eval mode
    model.to(device).eval()

    # Set Dataloader
    vid_path, vid_writer = None, None
    # load data from the path where the sub-region image is stored
    dataset = LoadImages(img_fn, img_size=img_size)

    # Get names and colors
    names = load_classes(name)
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Run inference
    t0 = time.time()
    for path, img, im0s, vid_cap in dataset:
        #count_aaa = int(path[-5])
        img = torch.from_numpy(img).to(device)
        img = img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)

        # Inference
        t1 = torch_utils.time_synchronized()

        # print(img.shape)

        pred = model(img)[0]
        t2 = torch_utils.time_synchronized()

        # Apply NMS
        pred = non_max_suppression(pred, conf_thres=conf_thres, iou_thres=iou_thres, classes=Paras_common['classes'], agnostic=Paras_common['agnostic_nms'])

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            s, im0 = '', im0s

            #save_path = str(Path(out) / Path(p).name)
            s += '%gx%g ' % img.shape[2:]  # print string
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string

                # Write results
                for *xyxy, conf, cls in det:
                    label = '%s %.2f' % (names[int(cls)], conf)
                    x1 = int(xyxy[0]+add_x)
                    x2 = int(xyxy[2]+add_x)
                    y1 = int(xyxy[1]+add_y)
                    y2 = int(xyxy[3]+add_y)
                    # y1 = int(xyxy[1]+int(path[-5]-1)*1136)
                    # y2 = int(xyxy[3]+int(path[-5]-1)*1136)
                    midPts_x = int((x1+x2)/2)
                    midPts_y = int((y1+y2)/2)
                    area = (x2-x1)*(y2-y1)
                    ear_len = max((x2-x1),(y2-y1))
                    if area < 800:
                        continue
                    else:
                        if min((x2-x1),(y2-y1)) < 20:
                            continue
                        else:
                            if label[0] == '1':
                                res.append((x1, y1, x2, y2, label[0], area, midPts_x, midPts_y, ear_len))
                            elif label[0] == '2':
                                if 0.8<abs((x2-x1)/(y2-y1))<1.2:
                                    res.append((x1, y1, x2, y2, label[0], area, midPts_x, midPts_y, ear_len))
                                else:
                                    pass
                            # plot_one_box(xyxy, im0, label=label,draw_point=draw_point, color=colors[int(cls)])

            # Print time (inference + NMS)
            # print('%sDone. (%.3fs)' % (s, t2 - t1))
    # print('Done. (%.3fs)' % (time.time() - t0))
    return res


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='*.cfg path')
    parser.add_argument('--names', type=str, default='data/mqms.names', help='*.names path')
    parser.add_argument('--weights', type=str, default='weights/best.pt', help='weights path')
    parser.add_argument('--source', type=str, default='data/samples', help='source')  # input file/folder, 0 for webcam
    parser.add_argument('--output', type=str, default='output', help='output folder')  # output folder
    parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.3, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.2, help='IOU threshold for NMS')
    parser.add_argument('--fourcc', type=str, default='mp4v', help='output video codec (verify ffmpeg support)')
    parser.add_argument('--half', action='store_true', help='half precision FP16 inference')
    parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1) or cpu')
    parser.add_argument('--view-img', action='store_true', help='display results')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    parser.add_argument('--draw-point', action='store_true', help='draw mid point instead of box')
    opt = parser.parse_args()
    # print(opt)

    with torch.no_grad():
        detect()


Stepping through with breakpoints, the out-of-memory error seems to occur during the call to detect. The image is 5569×4173 pixels, 10.2 MB.

22 answers

  • 专家-赤兔 (top creator, full-stack) 2024-04-24 13:48

    Quoted from the free WeChat mini program: 皆我百晓生

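    The CUDA out-of-memory error in your code may be related to the following:

    1. Model size: YOLO models carry a large number of weight parameters, and loading the full model takes a sizeable share of GPU memory regardless of image size. Check whether the YOLO model you are using is too complex, or has unnecessary extra layers that increase memory usage.

    2. Image resolution: your cornfield image is 5569×4173 pixels, which is very large. Even after it is cut with a sliding window, each sub-region may still be fairly high resolution, and higher-resolution inputs use more GPU memory during inference. Try lowering the input resolution or using a smaller window (scales['window_size']). Note that the step (scales['slip_window_step']) changes how many sub-regions there are, not how large each one is.

    3. Batching and memory management: the current code detects one sub-region per detect() call, so the batch size is effectively already 1. What matters is that the memory used by one call is released before the next call starts; if you restructure the loop, keep the number of sub-regions handled at a time small.

    4. Temporary variables not released: make sure intermediate results produced during inference (such as tensors) are released promptly so they don't accumulate in GPU memory. Check whether detect() holds on to tensors or variables it no longer needs. It is also worth noting that detect() as posted builds a new Darknet model and loads the weights on every call, that is, once per sub-region.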

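    Given the above, here are some suggestions for optimizing your code:

    Optimization 1: adjust the model

    Check whether the YOLO model you are using fits your hardware. If the model is too complex, consider replacing it with a lightweight YOLO variant (such as YOLOv3-Tiny or YOLOv4-Tiny).

    Optimization 2: lower the image resolution

    Before splitting, downsample the original image to reduce its resolution. For example, you can add the following lines where the image is read: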

    image = cv2.imread(Dimg)
    im_y, im_x = image.shape[:2]  # original height and width, needed before resizing
    image = cv2.resize(image, (int(im_x / scale_factor), int(im_y / scale_factor)), interpolation=cv2.INTER_LINEAR)
    

    Here scale_factor is the scaling ratio you choose; adjust it to your situation.
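
To see how much downsampling cuts the workload, the sketch below mirrors the padding and tiling arithmetic in start() and counts how many sub-regions, and therefore detect() calls, one image produces. The image size comes from the question. The window_size and step values of 416 are assumptions, since settings.py is not shown, and the helper names are made up for illustration.

```python
def padded_length(length, window_size, step):
    """Padded axis length, following the residuum logic in start()."""
    residuum = (length - window_size) % step
    return length + (step - residuum)


def tiles_along_axis(length, window_size, step):
    """Number of full-size windows that fit along one padded axis."""
    padded = padded_length(length, window_size, step)
    return (padded - window_size) // step + 1


def tile_count(im_x, im_y, window_size, step):
    """Total sub-regions, i.e. detect() calls, for one image."""
    return (tiles_along_axis(im_x, window_size, step)
            * tiles_along_axis(im_y, window_size, step))


# Assumed values: settings.py is not shown in the question
WINDOW_SIZE = 416
STEP = 416

for scale_factor in (1, 2):
    w = int(5569 / scale_factor)
    h = int(4173 / scale_factor)
    print(scale_factor, tile_count(w, h, WINDOW_SIZE, STEP))
```

With these assumed values, the count drops from 154 sub-regions at full size to 42 at scale_factor = 2. Keep in mind that downsampling also shrinks the seedlings, and detect() discards boxes with an area under 800 pixels, so the factor probably cannot be pushed very far.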

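    Optimization 3: process sub-regions in batches

    In start(), wrap sub-region detection in an iterable function and process the sub-regions in batches of a suitable size. For example: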

    BATCH_SIZE = 8
    
    def batch_detect(tiles, conf_thresh):
        # tiles: list of (cut_path, add_x, add_y) tuples collected in the tiling loop
        for i in range(0, len(tiles), BATCH_SIZE):
            batch_res = []
            for path, add_x, add_y in tiles[i:i+BATCH_SIZE]:
                # detect() takes the image path, its x/y offsets in the full image, and the threshold
                batch_res.extend(detect(path, add_x, add_y, conf_thresh))
            yield batch_res
    
    ...
    
    batched_candidates = batch_detect(tiles, conf_thresh)
    for batch in batched_candidates:
        candidates.extend(batch)
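
Batching the calls on its own does not change the fact that detect(), as posted, constructs a Darknet model and loads the weights every time it is called. A minimal sketch of loading once and reusing the model follows. It is not the project's actual code: build_model is a hypothetical stand-in for the setup lines in detect() (Darknet(...), load_state_dict(...), model.to(device).eval()), detect_tile is a made-up name, and the two paths are simply the defaults from the posted argparse block.

```python
from functools import lru_cache


def build_model(cfg_path, weights_path):
    """Hypothetical stand-in for the expensive model setup done in detect()."""
    build_model.calls += 1  # count how often the expensive setup actually runs
    return {"cfg": cfg_path, "weights": weights_path}


build_model.calls = 0


@lru_cache(maxsize=1)
def get_model(cfg_path, weights_path):
    # The first call builds the model; later calls with the same
    # arguments return the same cached object instead of rebuilding it.
    return build_model(cfg_path, weights_path)


def detect_tile(tile_path, add_x, add_y, conf=0.5):
    """Made-up per-tile entry point that reuses the cached model."""
    model = get_model("cfg/yolov3.cfg", "weights/best.pt")
    # ... inference on tile_path with `model` would go here ...
    return []


for k in range(5):
    detect_tile("tile_%d.jpg" % k, 0, 0)

print(build_model.calls)  # the setup ran once for five tiles
```

In the real project, build_model would return the Darknet instance already moved to the GPU, and the inference inside detect_tile would run under torch.no_grad().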
    

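    Optimization 4: clean up temporary variables

    Inside detect(), make sure tensors that are no longer needed are released. For example, use the torch.no_grad() context manager to disable gradient tracking and reduce memory usage; in the posted code, torch.no_grad() only wraps the call under the detect module's own __main__ block, so calls made from start() run with gradient tracking enabled. Also explicitly release tensors you no longer need, such as pred and img.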

    with torch.no_grad():
        ...
        pred = model(img)[0]
        ...
        del img, pred  # release tensors that are no longer needed
    

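    In summary, tuning the model, the image resolution, the batching strategy, and the handling of temporary variables should help relieve the CUDA out-of-memory error. Adjust the parameters and strategies above to fit your requirements and hardware.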


Question events

  • Closed by the system on May 2
  • Question created on April 24
