I'm not familiar with the gradio framework and need to add a feature to an existing page, described below. Please post complete code. Don't bother with answers churned out by ChatGPT; I tried that long ago:
1. On the gradio web page, add a dropdown for lora_type ('main_content', 'main_title', 'little_title'); the selection on the page determines the value of lora_type.
2. Based on the selected lora_type, re-run get_lora_model so that the parameters obtained by _launch_demo are reloaded (see the sketch at the end of this post for the kind of wiring I have in mind).
My current code:
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""A simple web interactive chat demo based on gradio."""
import tools
from datetime import datetime
tools.set_cache()
import os
from argparse import ArgumentParser
import gradio as gr
import mdtex2html
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from peft import PeftModel
DEFAULT_CKPT_PATH = r'D:\1'
lora_main_content = r'D:\2'
lora_main_title_path = r'D:\3'
lora_little_title_path = r'D:\4'
def time_now():
    # Return the current time as a formatted string
    return time.strftime( '%Y-%m-%d %H:%M:%S', time.localtime( time.time() ) )
def cost_time(time_start):
    # Print the total elapsed time since time_start
    time_end = time.time()
    time_dif = time_end - time_start
    print( f'Total time: {round( time_dif, 2 )}s' )
    # Print the current time
    print( f'Current time: {time_now()}' )
def _get_args():
parser = ArgumentParser()
parser.add_argument( "-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
help="Checkpoint name or path, default to %(default)r" )
parser.add_argument( "--cpu-only", action="store_true", help="Run demo with CPU only" )
parser.add_argument( "--share", action="store_true", default=False,
help="Create a publicly shareable link for the interface." )
parser.add_argument( "--inbrowser", action="store_true", default=True,
help="Automatically launch the interface in a new tab on the default browser." )
parser.add_argument( "--server-port", type=int, default=8999,
help="Demo server port." )
# parser.add_argument("--server-name", type=str, default="127.0.0.1",
parser.add_argument( "--server-name", type=str, default="0.0.0.0",
help="Demo server name." )
args = parser.parse_args()
return args
def get_lora_model(model,lora_type):
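    # Choose the adapter path, system prompt and generation budget for the
    # given lora_type, then attach that LoRA adapter to the base model.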
global system, lora_path, max_new_tokens, lora_model
if lora_type == 'main_content':
lora_path = lora_main_content
system = "1"
max_new_tokens = 2000
times = 1
elif lora_type == 'main_title':
lora_path = lora_main_title_path
system = "2"
max_new_tokens = 50
times = 5
    elif lora_type == 'little_title':
        lora_path = lora_little_title_path
        system = "3"
        max_new_tokens = 50
        times = 5
    else:
        # guard against the default lora_type=None, which would otherwise
        # crash later with an UnboundLocalError on lora_path
        raise ValueError( f'Unknown lora_type: {lora_type!r}' )
lora_model = PeftModel.from_pretrained(
model,
lora_path,
fp16=True,
trust_remote_code=True,
use_flash_attn=False,
device_map="cuda:0",
# adapter_name="eng_alpaca"
).eval()
return lora_model, system, max_new_tokens,times
def _load_model_tokenizer(args, lora_type=None):
tokenizer = AutoTokenizer.from_pretrained(
args.checkpoint_path, trust_remote_code=True
)
if args.cpu_only:
device_map = "cpu"
else:
        # device_map = "auto"
        device_map = "cuda:0"  # pin the model to the first GPU only
model = AutoModelForCausalLM.from_pretrained(
args.checkpoint_path,
fp16=True,
device_map=device_map,
trust_remote_code=True, use_flash_attn=False
).eval()
config = GenerationConfig.from_pretrained(
args.checkpoint_path, trust_remote_code=True
)
lora_model, system, max_new_tokens,times = get_lora_model( model, lora_type )
return lora_model, tokenizer, config, system,max_new_tokens,times
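# Override gr.Chatbot.postprocess (assigned below the function) so chat
# messages are converted from Markdown/LaTeX to HTML by mdtex2html.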
def postprocess(self, y):
if y is None:
return []
for i, (message, response) in enumerate( y ):
y[i] = (
None if message is None else mdtex2html.convert( message ),
None if response is None else mdtex2html.convert( response ),
)
return y
gr.Chatbot.postprocess = postprocess
def _parse_text(text):
lines = text.split( "\n" )
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate( lines ):
if "```" in line:
count += 1
items = line.split( "`" )
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f"<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace( "`", r"\`" )
line = line.replace( "<", "<" )
line = line.replace( ">", ">" )
line = line.replace( " ", " " )
line = line.replace( "*", "*" )
line = line.replace( "_", "_" )
line = line.replace( "-", "-" )
line = line.replace( ".", "." )
line = line.replace( "!", "!" )
line = line.replace( "(", "(" )
line = line.replace( ")", ")" )
line = line.replace( "$", "$" )
lines[i] = "<br>" + line
text = "".join( lines )
return text
def _launch_demo(args, model, tokenizer, config, system, max_new_tokens,times):
    max_tokens = gr.State( max_new_tokens )  # generation cap (compared against response length in predict)
def predict(_query, _chatbot, _task_history):
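        # Stream a response for _query; for title tasks this runs `times`
        # rounds and collects all candidates into one chat message.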
print( "User Input:", _query )
_query = _parse_text( _query )
        # start timing
        time_start = time.time()
# print( f"User: {_query}" )
_chatbot.append( (_query, "") )
full_response = ""
generated_tokens = 0
# _title_list = []
titles = ''
for i in range( times ):
for response in model.chat_stream( tokenizer, _query,
max_new_tokens=max_new_tokens,
history=_task_history,
system=system,
generation_config=config,
):
# _chatbot[-1] = (_query, _parse_text( response ))
# yield _chatbot
full_response = _parse_text( response )
if len( full_response ) >= max_tokens.value:
print( f"Reached max tokens ({max_tokens.value}): Stopping generation." )
break
# _title_list.append( full_response )
            titles += full_response + '<br>'  # gradio renders HTML, so use <br> rather than "\n"
        # emit all candidates together in a single chat message
_chatbot.append( (_query, titles) )
yield _chatbot[1:]
# print( f"History: {_task_history}" )
print( f"system: {system}" )
_task_history.append( (_query, full_response) )
# print( f"Qwen-Chat: {_parse_text( full_response )}" )
print( f"Qwen-Chat: {titles}" )
print( f"输出{len( _parse_text( full_response ) )}个字符" )
cost_time( time_start )
print( "Done".center( 50, "*" ) )
def regenerate(_chatbot, _task_history):
if not _task_history:
yield _chatbot
return
item = _task_history.pop( -1 )
_chatbot.pop( -1 )
yield from predict( item[0], _chatbot, _task_history )
def continue_chat(_chatbot, _task_history):
        '''
        Triggered by the Continue button.
        '''
if not _task_history:
yield _chatbot
return
yield from predict( "继续", _chatbot, _task_history )
def reset_user_input():
return gr.update( value="" )
def reset_state(_chatbot, _task_history):
_task_history.clear()
_chatbot.clear()
import gc
gc.collect()
torch.cuda.empty_cache()
return _chatbot
with gr.Blocks() as demo:
gr.Markdown( """""" )
        gr.Markdown(
            """\
<center><font size=5>LLM</font></center>""" )
gr.Markdown( """""" )
        chatbot = gr.Chatbot( label='LLM', elem_classes="control-height", height=500 )
        query = gr.Textbox( lines=2, label='Input' )
task_history = gr.State( [] )
with gr.Row():
            submit_btn = gr.Button( "🚀 Submit" )
            empty_btn = gr.Button( "🧹 Clear History" )
            # continue_btn = gr.Button( "🤔 Continue" )
            regen_btn = gr.Button( "🔄 Regenerate" )
submit_btn.click( predict, [query, chatbot, task_history], [chatbot], show_progress=True )
submit_btn.click( reset_user_input, [], [query] )
empty_btn.click( reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True )
regen_btn.click( regenerate, [chatbot, task_history], [chatbot], show_progress=True )
# continue_btn.click( continue_chat, [chatbot, task_history], [chatbot], show_progress=True )
demo.queue().launch(
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)
def main(lora_type):
args = _get_args()
model, tokenizer, config, system, max_new_tokens,times = _load_model_tokenizer( args, lora_type )
_launch_demo( args, model, tokenizer, config, system,max_new_tokens ,times)
if __name__ == '__main__':
lora_type = 'main_title'
# lora_type = 'little_title'
# lora_type = 'main_content'
main( lora_type )
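For reference, here is a minimal, self-contained sketch of the wiring I have in mind, so it is clear what I am asking for. It is untested; the names LORA_PRESETS, current, switch_lora, lora_dropdown and status are made up by me, and the preset dict is only a stand-in for the real get_lora_model() / PeftModel.from_pretrained() call:
import gradio as gr
# Stand-in for get_lora_model(): maps each lora_type to the settings that
# get_lora_model() would return (the real version would also load the
# adapter with PeftModel.from_pretrained and return the lora_model).
LORA_PRESETS = {
    'main_content': {'system': '1', 'max_new_tokens': 2000, 'times': 1},
    'main_title': {'system': '2', 'max_new_tokens': 50, 'times': 5},
    'little_title': {'system': '3', 'max_new_tokens': 50, 'times': 5},
}
# Mutable settings shared with the predict() callback; the real demo would
# also keep the freshly loaded lora_model in here.
current = dict( LORA_PRESETS['main_title'], lora_type='main_title' )
def switch_lora(lora_type):
    # Real version: lora_model, system, max_new_tokens, times = get_lora_model( base_model, lora_type )
    current.update( LORA_PRESETS[lora_type], lora_type=lora_type )
    return f'Switched to {lora_type}: {current}'
with gr.Blocks() as demo:
    lora_dropdown = gr.Dropdown(
        choices=['main_content', 'main_title', 'little_title'],
        value='main_title',
        label='lora_type',
    )
    status = gr.Textbox( label='Current settings', interactive=False )
    # Reload the settings (and, in the real code, the LoRA adapter)
    # whenever the dropdown selection changes.
    lora_dropdown.change( switch_lora, inputs=[lora_dropdown], outputs=[status] )
demo.queue().launch()
What I cannot work out is how to thread this through _launch_demo so that predict() picks up the reloaded lora_model, system, max_new_tokens and times without restarting the script.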