I wrote a script that uses the Hugging Face transformers API to benchmark inference speed for llama2 and llama3, and in my results llama2-7b comes out faster than llama3-8b. How do you all run this kind of benchmark, and do your results look different? The script is below; it was run on a single RTX 4090. Is there anything wrong with my test code? Corrections and suggestions are welcome.
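For reference, the script is invoked like this (fp16, batch size 3; the flags match get_argparse below):

python main.py --v 2+ --bs 3 --dt 1   # llama2-7b-instruct
python main.py --v 3+ --bs 3 --dt 1   # llama3-8b-instruct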
MyDataset.py
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """A dummy dataset that yields the same prompt repeatedly."""

    def __init__(self, dummy=0):
        self.dummy = dummy

    def __len__(self):
        # The default run uses 51 prompts; a non-zero `dummy` overrides the length.
        if self.dummy == 0:
            return 51
        return self.dummy

    def __getitem__(self, i):
        return "Please introduce yourself in detail."

main.py
import transformers
import torch
import argparse
from LiUtils import clock
from MyDataset import MyDataset
from tqdm.auto import tqdm

model_load_time = 0.0

def get_argparse():
    parser = argparse.ArgumentParser(description="Script for running LLaMA inference")
    parser.add_argument("--dt", type=int, default=1,
                        help="dtype to load the model with: 0=float32, 1=float16, 2=bfloat16, 3=auto")
    parser.add_argument("--v", type=str, default="3",
                        help="model version: one of 2+, 3+, 3")
    parser.add_argument("--bs", type=int, default=3, help="batch size")
    return parser.parse_args()

model_dict = {
    "3+": "/data/models/llama3-8b-instruct",
    "3": "/data/models/llama3-8B-hf",
    "2+": "/data/models/llama2-7b-instruct"
}

# def setGPUDevice(device_list):
# deviceSentence = ",".join([str(i) for i in device_list])
# subprocess.run(["export", f"CUDA_VISIBLE_DEVICES={deviceSentence}"])
def load_model_pipeline(dtype, llama_version):
    """
    Load the model and tokenizer and return a text-generation pipeline.
    :param dtype: torch dtype (or "auto") used to load the weights
    :param llama_version: key into model_dict ("2+", "3+" or "3")
    :return: a transformers text-generation pipeline
    """
    if llama_version not in model_dict:
        print("wrong key! valid keys:", list(model_dict.keys()))
        exit(123)
    model_dir = model_dict[llama_version]
    print(f"load model path: {model_dir}")
    tokenizer_dir = model_dir
    generator = transformers.pipeline(
        "text-generation",
        model=model_dir,
        device_map="auto",
        torch_dtype=dtype,
        tokenizer=tokenizer_dir,
        config=model_dir
    )
    print(generator.model)
    return generator

def load_dataset():
    return MyDataset()

def inference_batching(generator, dataset, batch_size):
    """
    Run batched inference and time the prefill + decode phase.
    :param batch_size: pipeline batch size
    :param dataset: dataset of prompts
    :param generator: text-generation pipeline
    :return:
    """
    # Warm-up pass so CUDA initialization does not pollute the timed run.
    dummy_dataset = MyDataset(10)
    generator(dummy_dataset, batch_size=1, max_length=128, do_sample=False, truncation=True)
    print("GPU warmed up!")
    with clock("Prefill and decode time"):
        print(f"Streaming batch_size={batch_size}")
        for out in tqdm(generator(dataset, batch_size=batch_size, max_length=128,
                                  do_sample=False, truncation=True),
                        total=len(dataset)):
            print(out)

def main():
    # Parse command-line arguments
    parser = get_argparse()
    batch_size = parser.bs
    llama_version = parser.v
    dtype = torch.float16
    if parser.dt == 0:
        dtype = torch.float32
    elif parser.dt == 1:
        dtype = torch.float16
    elif parser.dt == 2:
        dtype = torch.bfloat16
    elif parser.dt == 3:
        dtype = "auto"
    # Load the model
    with clock("Load time (include tokenizer and model)"):
        generator = load_model_pipeline(dtype, llama_version)
    # Load the dataset
    dataset = load_dataset()
    # Batched inference
    inference_batching(generator, dataset, batch_size)


if __name__ == '__main__':
    main()
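I'm also considering cross-checking the pipeline numbers by timing model.generate() directly and reporting tokens per second, so that differences in output length between the two models don't skew a pure wall-clock comparison. A rough sketch (the model path, prompt, and generation settings below are just placeholders, not part of the script above):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "/data/models/llama2-7b-instruct"  # or the llama3 path from model_dict
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16).to("cuda")

prompt = "Please introduce yourself in detail."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Warm-up run so CUDA initialization is not counted in the timing
model.generate(**inputs, max_new_tokens=16, do_sample=False, pad_token_id=tokenizer.eos_token_id)

torch.cuda.synchronize()
start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=128, do_sample=False, pad_token_id=tokenizer.eos_token_id)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

# Throughput in generated tokens per second
new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
print(f"{new_tokens} new tokens in {elapsed:.3f}s -> {new_tokens / elapsed:.1f} tokens/s")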