childhood cat 2024-03-28 15:03 采纳率: 40%
浏览 9

warnings.warn('Falling back to the old format < 1.6. This support will be '

复现PyDeepFakeDet项目运行extract_faces.py出现UserWarning: Falling back to the old format < 1.6. This support will be deprecated in favor of default zipfile format introduced in 1.6. Please redo torch.save() to save it in the new zipfile format.
warnings.warn('Falling back to the old format < 1.6. This support will be '
请问要怎么解决这个问题啊?
我怀疑是solve函数里的model = get_model("resnet50_2020-07-20", max_size=1024, device=device)出的问题


import argparse
import os
import pickle
import random
from multiprocessing import Process, Queue
from time import time

import cv2
import torch
from retinaface.pre_trained_models import get_model

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

parser = argparse.ArgumentParser('script', add_help=False)
parser.add_argument('--root_dir', type=str)
parser.add_argument('--save_dir', type=str)
parser.add_argument('--process', default=1, type=int)
args = parser.parse_args()


def read_list(path):
    ls = []
    with open(path, 'r') as f:
        for line in f:
            ls.append(line.strip().split())
    return ls


def write_list(path, ls):
    with open(path, 'w') as w:
        for line in ls:
            #file.write("%s\n" % item)
            print(' '.join([str(ele) for ele in line]), file=w)


def gen_dirs(raw, new):
    if not os.path.exists(new):
        os.mkdir(new)
    for root, dirs, files in os.walk(raw):
        for dir in dirs:
            whole_path = os.path.join(root, dir)
            rel_path = os.path.relpath(whole_path, raw)
            new_path = os.path.join(new, rel_path)
            if not os.path.exists(new_path):
                os.mkdir(new_path)


def can_seg(img_path, save_path, model=None, scale=1.3):
    img = cv2.imread(img_path)
    h, w, c = img.shape
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    annotation = model.predict_jsons(
        img, confidence_threshold=0.3
    ) 
    if len(annotation[0]['bbox']) == 0:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite('./tmp2/%d.jpg' % (random.randint(1, 100)), img)
        return False
    x1, y1, x2, y2 = annotation[0]['bbox']
    x1, y1, x2, y2 = list(
        map(
            int,
            [
                x1 - (x2 - x1) * (scale - 1) / 2,
                y1 - (y2 - y1) * (scale - 1) / 2,
                x2 + (x2 - x1) * (scale - 1) / 2,
                y2 + (y2 - y1) * (scale - 1) / 2,
            ],
        )
    )
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(w - 1, x2), min(h - 1, y2)
    img_face = img[y1:y2, x1:x2, :]
    img_face = cv2.cvtColor(img_face, cv2.COLOR_RGB2BGR)
    cv2.imwrite(save_path, img_face)
    return True


def solve(process_id, raw_list):
    device = 'cuda:%d' % process_id
    model = get_model("resnet50_2020-07-20", max_size=1024, device=device)
    model.eval()
    new_list = []
    check = 20
    start_time = time()
    for i, line in enumerate(raw_list):
        raw_path = os.path.join(args.root_dir, line[0])
        save_path = os.path.join(
            args.save_dir, os.path.relpath(raw_path, args.root_dir)
        )
        try:
            #print("#")
            if 'tmp' not in raw_path:
                if os.path.exists(save_path) or can_seg(
                    raw_path, save_path, model
                ):
                    new_list.append(line)
        except Exception as e:
            print(e, raw_path)
            exit(0)
        if i % check == check - 1:
            ET = time() - start_time
            start_time = time()
            ETA = ET / check * (len(raw_list) - i - 1)
            print(
                "process:%d %d/%d ET:%.2fmin ETA:%.2fmin "
                % (process_id, i + 1, len(raw_list), ET / 60, ETA / 60)
            )
    with open('tmp/%d.pkl' % process_id, 'wb') as w:
        pickle.dump(new_list, w)


if __name__ == '__main__':
    if not os.path.exists('tmp'):
        os.mkdir('tmp')
    torch.multiprocessing.set_start_method('spawn')
    raw_list_path = os.path.join(args.root_dir, 'image_list.txt')
    #print("1")
    new_list_path = os.path.join(args.save_dir, 'image_list_extract.txt')
   # print("2")
    raw_list = read_list(raw_list_path)
   # print("4")
    new_list = []
    gen_dirs(args.root_dir, args.save_dir)
    #print("5")
    # multi-process
    num_process = args.process
    sub_raw_list = []
    n = len(raw_list)
    step = n // num_process
    j = 0
    random.shuffle(raw_list)
    for i in range(0, n, step):
        j += 1
        if j == num_process:
            sub_raw_list.append(raw_list[i:n])
            #print(sub_raw_list)
            break
        else:
            sub_raw_list.append(raw_list[i : i + step])
    process_list = []
    Q = Queue()
    for i, item in enumerate(sub_raw_list):
        cur_process = Process(target=solve, args=(i, item))
        process_list.append(cur_process)
    for process in process_list:
        process.start()
    for process in process_list:
        process.join()

    # merge
    print("merge")
    #print(num_process)
    sub_image_list = []
    for i in range(num_process):
        with open("tmp/%d.pkl" % i, 'rb') as f:
            #print(pickle.load(f))
            sub_image_list.append(pickle.load(f))
            #print("sub_image_list")
    image_list = []
    # print(sub_image_list)
    for ele in sub_image_list:
        for line in ele:
            image_list.append(line)
            #print(line)
    write_list(new_list_path, image_list)

  • 写回答

1条回答 默认 最新

  • 梦回阑珊 2024-03-28 16:10
    关注

    【以下回答由 chatGPT 生成】
    回复不易,麻烦关注下博主,谢谢!!!

    这个警告 UserWarning: Falling back to the old format < 1.6. This support will be deprecated in favor of default zipfile format introduced in 1.6. Please redo torch.save() to save it in the new zipfile format. 是由于您使用了 PyTorch 的 torch.save() 函数保存模型时,使用了旧的保存格式,而 PyTorch 1.6 版本引入了默认的 zipfile 格式,并且官方计划在未来版本中淘汰旧格式。因此,建议您按照警告提示的建议,重新使用 torch.save() 函数保存模型以支持新的 zipfile 格式。

    这里提供一个示例来说明如何使用 torch.save() 函数以新的 zipfile 格式保存模型:

    # 假设您的模型是 model,保存路径是 model_path
    torch.save(model.state_dict(), model_path)  # 保存模型权重
    
    
    

    如果您的模型有额外的信息需要保存(如优化器状态、额外的配置等),您可以创建一个字典,将需要保存的信息整合到字典中,然后一起保存:

    # 假设您需要保存模型权重、优化器状态和其他额外信息
    # 请根据实际情况修改下面的示例代码
    state = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'additional_info': additional_info,
    }
    
    torch.save(state, model_path)  # 保存模型和其他信息
    
    
    

    请将以上示例代码中的 model 替换为您实际使用的模型变量名,并将 model_path 替换为您希望保存模型的文件路径。这样保存的模型将采用新的 zipfile 格式,不会再出现上述警告

    评论

报告相同问题?

问题事件

  • 创建了问题 3月28日

悬赏问题

  • ¥20 wpf datagrid单元闪烁效果失灵
  • ¥15 券商软件上市公司信息获取问题
  • ¥100 ensp启动设备蓝屏,代码clock_watchdog_timeout
  • ¥15 Android studio AVD启动不了
  • ¥15 陆空双模式无人机怎么做
  • ¥15 想咨询点问题,与算法转换,负荷预测,数字孪生有关
  • ¥15 C#中的编译平台的区别影响
  • ¥15 软件供应链安全是跟可靠性有关还是跟安全性有关?
  • ¥15 电脑蓝屏logfilessrtsrttrail问题
  • ¥20 关于wordpress建站遇到的问题!(语言-php)(相关搜索:云服务器)