import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os

# The 'spaces' library provides the ZeroGPU decorator used below
import spaces

# --- Configuration ---
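# The token is read from the Space's environment; it is only needed if the selected models are gated or private.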
HF_TOKEN = os.getenv("HF_TOKEN")

MODEL_OPTIONS = {
    "Qwen1.5-1.8B-Chat": "Qwen/Qwen1.5-1.8B-Chat",
    "Qwen2.5-Coder-3B": "Qwen/Qwen2.5-Coder-3B",
}

# --- Model Loading Cache ---
loaded_models = {}

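# Load a model/tokenizer pair on first use and reuse the cached copy on later calls.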
def get_model_and_tokenizer(model_name_key):
    if model_name_key not in loaded_models:
        model_id = MODEL_OPTIONS[model_name_key]
        print(f"Loading model: {model_id}...")
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype="auto",
                device_map="auto",
                token=HF_TOKEN
            )
            tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
            loaded_models[model_name_key] = (model, tokenizer)
            print(f"Model {model_id} loaded successfully.")
        except Exception as e:
            print(f"Error loading model {model_id}: {e}")
            if model_name_key in loaded_models:
                del loaded_models[model_name_key]
            raise gr.Error(f"Failed to load model {model_name_key}. Please check the model ID and your Hugging Face token permissions. Error: {e}")
    return loaded_models[model_name_key]

# --- Inference Function ---
@spaces.GPU(duration=120)  # Request a ZeroGPU allocation for up to 120 seconds per call
def generate_response(prompt_text, model_choice, max_new_tokens=512, temperature=0.7, top_p=0.9):
    if not prompt_text:
        return "Please enter a prompt."
    if not model_choice:
        return "Please select a model."

    try:
        model, tokenizer = get_model_and_tokenizer(model_choice)
    except Exception as e:
        return str(e)

    device = model.device

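    # Chat-tuned models expect the conversation wrapped in their chat template;
    # other models are prompted with the raw text.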
    if "Chat" in model_choice:
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt_text}
        ]
        try:
            input_text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        except Exception as e:
            print(f"Warning: Could not apply chat template for {model_choice}: {e}. Using prompt as is.")
            input_text = prompt_text
    else:
        input_text = prompt_text

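    # Tokenize the prompt and move the tensors to the model's device.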
    model_inputs = tokenizer([input_text], return_tensors="pt").to(device)

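    # Sample a completion and keep only the tokens generated after the prompt.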
    try:
        generated_ids = model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True
        )
        response_ids = generated_ids[0][model_inputs.input_ids.shape[-1]:]
        response_text = tokenizer.decode(response_ids, skip_special_tokens=True)
    except Exception as e:
        print(f"Error during generation with {model_choice}: {e}")
        return f"Error generating response: {e}"

    return response_text

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# LLM Coding & Math Experiment")
    gr.Markdown("Query Qwen1.5-1.8B-Chat or Qwen Code models using ZeroGPU.")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Select Model",
            choices=list(MODEL_OPTIONS.keys()),
            value=list(MODEL_OPTIONS.keys())[0]
        )
    with gr.Row():
        prompt_input = gr.Textbox(label="Enter your prompt:", lines=4, placeholder="e.g., Write a Python function to calculate factorial, or What is the capital of France?")
    with gr.Row():
        output_text = gr.Textbox(label="Model Response:", lines=8, interactive=False)

    with gr.Row():
        submit_button = gr.Button("Generate Response", variant="primary")

    with gr.Accordion("Advanced Settings", open=False):
        max_new_tokens_slider = gr.Slider(minimum=32, maximum=2048, value=512, step=32, label="Max New Tokens")
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P")

    submit_button.click(
        fn=generate_response,
        inputs=[prompt_input, model_dropdown, max_new_tokens_slider, temperature_slider, top_p_slider],
        outputs=output_text,
        api_name="generate"
    )

    gr.Markdown("## Notes:")
    gr.Markdown(
        "- Ensure you have accepted the terms of use for the selected Qwen models on the Hugging Face Hub.\n"
        "- Model loading can take some time, especially on the first run or when switching models.\n"
        "- This Space runs on ZeroGPU, which means GPU resources are allocated dynamically."
    )

if __name__ == "__main__":
    # The default launch settings work on Spaces; pass share=True only when
    # testing locally and a public link is needed.
    demo.launch()