kynow2 2024-03-11 16:13 · acceptance rate: 0%
136 views
Closed

Python: modify a Gradio page to add a dropdown list and reload a function in real time

I'm not familiar with the gradio framework and need to add a feature to an existing page, described below. Please give complete code. (Don't bother pasting raw ChatGPT output; I've already tried that.)

1. On the gradio web page, add a dropdown list for lora_type ('main_content', 'main_title', 'little_title'); the selection on the page determines the value of lora_type.
2. Based on the selected lora_type, re-run the get_lora_model function so that the parameters received by _launch_demo are reloaded. (See the sketch after this list for the kind of interaction I mean.)
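
To make the intent concrete, the interaction I'm after looks roughly like this (a throwaway sketch with placeholder names, not the real code):

import gradio as gr

def on_lora_type_change(lora_type):
    # Placeholder: this is where get_lora_model would be re-run
    return f"lora_type is now: {lora_type}"

with gr.Blocks() as demo:
    lora_type = gr.Dropdown( choices=['main_content', 'main_title', 'little_title'],
                             value='main_title', label='lora_type' )
    status = gr.Textbox( label='status' )
    # Fire whenever the dropdown selection changes
    lora_type.change( on_lora_type_change, inputs=[lora_type], outputs=[status] )

demo.launch()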

The original code is as follows:


# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A simple web interactive chat demo based on gradio."""

import tools
from datetime import datetime

tools.set_cache()

import os
from argparse import ArgumentParser

import gradio as gr
import mdtex2html
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from peft import PeftModel


DEFAULT_CKPT_PATH = r'D:\1'
lora_main_content = r'D:\2'
lora_main_title_path = r'D:\3'
lora_little_title_path = r'D:\4'


def time_now():
    # Return the current time as a formatted string
    return time.strftime( '%Y-%m-%d %H:%M:%S', time.localtime( time.time() ) )


def cost_time(time_start):
    # Compute and print the total elapsed time

    time_end = time.time()
    time_dif = time_end - time_start

    print( f'Total time: {round( time_dif, 2 )}s' )

    # Print the current time
    print( f'Current time: {time_now()}' )


def _get_args():
    parser = ArgumentParser()
    parser.add_argument( "-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                         help="Checkpoint name or path, default to %(default)r" )

    parser.add_argument( "--cpu-only", action="store_true", help="Run demo with CPU only" )

    parser.add_argument( "--share", action="store_true", default=False,
                         help="Create a publicly shareable link for the interface." )
    parser.add_argument( "--inbrowser", action="store_true", default=True,
                         help="Automatically launch the interface in a new tab on the default browser." )
    parser.add_argument( "--server-port", type=int, default=8999,
                         help="Demo server port." )
    # parser.add_argument("--server-name", type=str, default="127.0.0.1",
    parser.add_argument( "--server-name", type=str, default="0.0.0.0",
                         help="Demo server name." )

    args = parser.parse_args()
    return args


def get_lora_model(model, lora_type):
    global system, lora_path, max_new_tokens, lora_model

    if lora_type == 'main_content':
        lora_path = lora_main_content
        system = "1"  
        max_new_tokens = 2000
        times = 1
    elif lora_type == 'main_title':
        lora_path = lora_main_title_path
        system = "2"  
        max_new_tokens = 50
        times = 5
    elif lora_type == 'little_title':
        lora_path = lora_little_title_path
        system = "3"
        max_new_tokens = 50
        times = 5
    else:
        # Fail loudly instead of leaving lora_path/times unbound
        raise ValueError( f"Unknown lora_type: {lora_type!r}" )

    lora_model = PeftModel.from_pretrained(
        model,
        lora_path,
        fp16=True,
        trust_remote_code=True,
        use_flash_attn=False,
        device_map="cuda:0",
        # adapter_name="eng_alpaca"
        ).eval()

    return lora_model, system, max_new_tokens, times




def _load_model_tokenizer(args, lora_type=None):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        # device_map = "auto"
        device_map = "cuda:0"  # Use only the first GPU

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        fp16=True,
        device_map=device_map,
        trust_remote_code=True, use_flash_attn=False

    ).eval()

    config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )

    lora_model, system, max_new_tokens, times = get_lora_model( model, lora_type )

    return lora_model, tokenizer, config, system, max_new_tokens, times


def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate( y ):
        y[i] = (
            None if message is None else mdtex2html.convert( message ),
            None if response is None else mdtex2html.convert( response ),
        )
    return y


# Monkey-patch Chatbot so messages are rendered through mdtex2html
gr.Chatbot.postprocess = postprocess


def _parse_text(text):
    # Convert markdown-ish model output (incl. ``` code fences) into HTML
    lines = text.split( "\n" )
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate( lines ):
        if "```" in line:
            count += 1
            items = line.split( "`" )
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f"<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace( "`", r"\`" )
                    line = line.replace( "<", "&lt;" )
                    line = line.replace( ">", "&gt;" )
                    line = line.replace( " ", "&nbsp;" )
                    line = line.replace( "*", "&ast;" )
                    line = line.replace( "_", "&lowbar;" )
                    line = line.replace( "-", "&#45;" )
                    line = line.replace( ".", "&#46;" )
                    line = line.replace( "!", "&#33;" )
                    line = line.replace( "(", "&#40;" )
                    line = line.replace( ")", "&#41;" )
                    line = line.replace( "$", "&#36;" )
                lines[i] = "<br>" + line
    text = "".join( lines )
    return text


def _launch_demo(args, model, tokenizer, config, system, max_new_tokens, times):
    max_tokens = gr.State( max_new_tokens )  # Track the max-token limit as session state

    def predict(_query, _chatbot, _task_history):
        print( "User Input:", _query )
        _query = _parse_text( _query )

        # Start timing
        time_start = time.time()

        # print( f"User: {_query}" )
        _chatbot.append( (_query, "") )

        full_response = ""
        generated_tokens = 0
        # _title_list = []
        titles = ''

        for i in range( times ):
            for response in model.chat_stream( tokenizer, _query,
                                               max_new_tokens=max_new_tokens, 
                                               history=_task_history,
                                               system=system,
                                               generation_config=config,
                                               ):
                # _chatbot[-1] = (_query, _parse_text( response ))
                # yield _chatbot

                full_response = _parse_text( response )

                if len( full_response ) >= max_tokens.value:  # rough character-based cap
                    print( f"Reached max tokens ({max_tokens.value}): Stopping generation." )
                    break
            # _title_list.append( full_response )
            titles += full_response + '<br>'  # Gradio HTML does not render "\n"; use <br>

            # Collect all generations and emit them together
            _chatbot.append( (_query, titles) )
            yield _chatbot[1:]

        # print( f"History: {_task_history}" )
        print( f"system: {system}" )
        _task_history.append( (_query, full_response) )
        # print( f"Qwen-Chat: {_parse_text( full_response )}" )
        print( f"Qwen-Chat: {titles}" )
        print( f"输出{len( _parse_text( full_response ) )}个字符" )
        cost_time( time_start )
        print( "Done".center( 50, "*" ) )

    def regenerate(_chatbot, _task_history):
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop( -1 )
        _chatbot.pop( -1 )
        yield from predict( item[0], _chatbot, _task_history )

    def continue_chat(_chatbot, _task_history):
        '''
        Handle the Continue button
        '''
        if not _task_history:
            yield _chatbot
            return
        yield from predict( "继续", _chatbot, _task_history )

    def reset_user_input():
        return gr.update( value="" )

    def reset_state(_chatbot, _task_history):
        _task_history.clear()
        _chatbot.clear()
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown( """""" )
        gr.Markdown(
            """\
<center><font size=5>大模型</font></center>""" )
        gr.Markdown( """""" )

        chatbot = gr.Chatbot( label='大模型', elem_classes="control-height", height=500 )
        query = gr.Textbox( lines=2, label='输入' )
        task_history = gr.State( [] )

        with gr.Row():
            submit_btn = gr.Button( "🚀 Submit (发送)" )
            empty_btn = gr.Button( "🧹 Clear History (清除历史)" )
            # continue_btn = gr.Button( "🤔 Continue (继续)" )
            regen_btn = gr.Button( "️🔄 Regenerate (重试)" )

        submit_btn.click( predict, [query, chatbot, task_history], [chatbot], show_progress=True )
        submit_btn.click( reset_user_input, [], [query] )
        empty_btn.click( reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True )
        regen_btn.click( regenerate, [chatbot, task_history], [chatbot], show_progress=True )
        # continue_btn.click( continue_chat, [chatbot, task_history], [chatbot], show_progress=True )

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main(lora_type):
    args = _get_args()

    model, tokenizer, config, system, max_new_tokens, times = _load_model_tokenizer( args, lora_type )

    _launch_demo( args, model, tokenizer, config, system, max_new_tokens, times )


if __name__ == '__main__':
    lora_type = 'main_title'  
    # lora_type = 'little_title' 
    # lora_type = 'main_content' 

    main( lora_type )




6 answers

  • 叫兽-郭老师 (Rising Star Creator, Java) 2024-03-11 16:15

    🌈🌈🌈 Written with reference to Tongyi Qianwen, provided by 郭老师's fan 雅思莫了 🌈🌈🌈
    You can try the following. If anything is wrong, reply in the comments and I will keep following up until you are satisfied.

    To implement this, add a dropdown to the gradio page and, when the selected lora_type changes, reload the adapter via get_lora_model. Here is a modified code example:

    import gradio as gr

    # ... (keep the rest of the original code unchanged)

    def _get_args():
        # ... (keep the original arguments unchanged)
        parser.add_argument("--lora-type", type=str, default='main_title',
                            help="Select lora type: 'main_content', 'main_title' or 'little_title'")
        args = parser.parse_args()
        return args

    def main():
        args = _get_args()
        lora_type = args.lora_type
        model, tokenizer, config, system, max_new_tokens, times = _load_model_tokenizer(args, lora_type)
        _launch_demo(args, model, tokenizer, config, system, max_new_tokens, times, lora_type)

    # In _launch_demo (whose signature now also takes lora_type), add the
    # following inside the existing `with gr.Blocks() as demo:` block, after
    # chatbot and task_history are created:

            # Dropdown for selecting the LoRA type. Note: gr.Dropdown takes
            # `choices=` (not `options=`) and has no `callback=` or
            # `description=` parameters; events are wired with .change() below.
            lora_dropdown = gr.Dropdown(
                choices=['main_content', 'main_title', 'little_title'],
                label="Select Lora Type",
                value=lora_type,
            )

            def dropdown_callback(value, _chatbot, _task_history):
                # Reload the adapter for the newly selected type and clear the
                # history. `nonlocal` rebinds the names that predict() closes over.
                nonlocal model, system, max_new_tokens, times
                model, system, max_new_tokens, times = get_lora_model(model, value)
                _task_history.clear()
                _chatbot.clear()
                return _chatbot

            lora_dropdown.change(dropdown_callback,
                                 inputs=[lora_dropdown, chatbot, task_history],
                                 outputs=[chatbot])

    if __name__ == '__main__':
        main()


    In this modified code, we first add a command-line argument --lora-type in _get_args to set the default lora_type (e.g. start the demo with `python web_demo.py --lora-type main_content`, substituting your own script name). A gradio.Dropdown is then created inside _launch_demo's Blocks, offering the three type options and using the command-line value as its initial selection.

    When the user changes the dropdown selection, the .change() event runs dropdown_callback, which reloads the adapter via get_lora_model for the new lora_type, rebinds the variables that predict() reads, and clears the chat history and task history. One caveat: get_lora_model wraps whatever model it is given in a new PeftModel, so repeated switching stacks adapters; if you expect frequent switching, keep a reference to the base model and always pass that in.

    Note that gradio itself has no built-in mechanism for dynamically reloading a model; the approach here reloads inside the event handler and discards the existing chat content. If you need richer interaction than this, you may have to customize gradio further or combine it with another frontend framework (such as React.js).
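
    If you want to sanity-check the dropdown wiring before touching any model code, here is a self-contained toy sketch of the same pattern (everything in it is a placeholder; load_settings stands in for get_lora_model):

    import gradio as gr

    def load_settings(lora_type):
        # Placeholder for get_lora_model: return the per-type settings
        return {'lora_type': lora_type,
                'max_new_tokens': 2000 if lora_type == 'main_content' else 50}

    with gr.Blocks() as demo:
        settings = gr.State(load_settings('main_title'))
        dropdown = gr.Dropdown(choices=['main_content', 'main_title', 'little_title'],
                               value='main_title', label='Select Lora Type')
        status = gr.Textbox(label='status')

        def switch(lora_type):
            # Reload the settings for the new type; the page is not refreshed
            new_settings = load_settings(lora_type)
            return new_settings, f'reloaded: {new_settings}'

        dropdown.change(switch, inputs=[dropdown], outputs=[settings, status])

    demo.launch()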
