import spaces
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import json

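# Standard Lean 4 preamble prepended to every formal statement: Mathlib and
# Aesop imports, no elaboration heartbeat limit, common namespaces opened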
LEAN4_DEFAULT_HEADER = (
    "import Mathlib\n"
    "import Aesop\n\n"
    "set_option maxHeartbeats 0\n\n"
    "open BigOperators Real Nat Topology Rat\n"
)

title = "# 🙋🏻‍♂️Welcome to 🌟Tonic's 🌕💉👨🏻‍🔬Moonshot Math"

description = """
     **Kimina-Prover-72B** is a state-of-the-art large formal reasoning model for Lean 4, achieving **80%+ pass rate** on the miniF2F benchmark, outperforming all prior works.\
Trained with Reinforcement Learning, 72B parameters, and a 32K token context window.\
- [Kimina-Prover-Preview GitHub](https://github.com/MoonshotAI/Kimina-Prover-Preview)\
- [Hugging Face: AI-MO/Kimina-Prover-72B](https://huggingface.co/AI-MO/Kimina-Prover-72B)\
- [Kimina Prover blog](https://huggingface.co/blog/AI-MO/kimina-prover)\
- [unimath dataset](https://huggingface.co/datasets/introspector/unimath)\
"""

citation = """> **Citation:**
> ```
> @article{kimina_prover_2025,
>   title = {Kimina-Prover Preview: Towards Large Formal Reasoning Models with Reinforcement Learning},
>   author = {Wang, Haiming and Unsal, Mert and ...},
>   year = {2025},
>   url = {http://arxiv.org/abs/2504.11354},
> }
> ```
"""


joinus ="""
### Join us:
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻  
[![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP)  
On 🤗Huggingface: [MultiTransformer](https://huggingface.co/MultiTransformer)  
On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)  
🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""

# Build the initial system prompt
SYSTEM_PROMPT = "You are an expert in mathematics and Lean 4."

# Helper to build a Lean4 code block
def build_formal_block(formal_statement, informal_prefix=""):
    return (
        f"{LEAN4_DEFAULT_HEADER}\n"
        f"{informal_prefix}\n"
        f"{formal_statement}"
    )
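# Example (hypothetical statement): build_formal_block("theorem t : 1 = 1 := by rfl")
# returns the Lean header, a blank line for the empty informal prefix, then the statement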

# Helper to extract the first Lean4 code block from text
def extract_lean4_code(text):
    code_block = re.search(r"```lean4(.*?)(```|$)", text, re.DOTALL)
    if code_block:
        code = code_block.group(1)
        lines = [line for line in code.split('\n') if line.strip()]
        return '\n'.join(lines)
    return text.strip()
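# e.g. extract_lean4_code("proof:\n```lean4\ntheorem t : 1 = 1 := rfl\n```")
# returns "theorem t : 1 = 1 := rfl"; blank lines inside the block are dropped,
# and the text is returned as-is when no ```lean4 fence is found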

# Example problems
unimath1 = """Goal:
  X : UU
  Y : UU
  P : UU
  xp : (X → P) → P
  yp : (Y → P) → P
  X0 : X × Y → P
  x : X
  ============================
   (Y → P)"""

unimath2 = """Goal:
    R : ring  M : module R
  ============================
   (islinear (idfun M))"""

unimath3 = """Goal:
    X : UU  i : nat  b : hProptoType (i < S i)  x : Vector X (S i)  r : i = i
  ============================
   (pr1 lastelement = pr1 (i,, b))"""

unimath4 = """Goal:
    X : dcpo  CX : continuous_dcpo_struct X  x : pr1hSet X  y : pr1hSet X
  ============================
   (x ⊑ y ≃ (∀ i : approximating_family CX x, approximating_family CX x i ⊑ y))"""

additional_info_prompt = "/-Explain using mathematics-/\n"

examples = [
    [unimath1, additional_info_prompt, 2500],
    [unimath2, additional_info_prompt, 2500],
    [unimath3, additional_info_prompt, 2500],
    [unimath4, additional_info_prompt, 2500]
]

model_name = "AI-MO/Kimina-Prover-72B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)

# Set generation config
model.generation_config = GenerationConfig.from_pretrained(model_name)
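# The checkpoint defines no dedicated pad token, so EOS is reused for padding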
model.generation_config.pad_token_id = model.generation_config.eos_token_id
model.generation_config.do_sample = True
model.generation_config.temperature = 0.6
model.generation_config.top_p = 0.95

# Initialize chat history with system prompt
def init_chat(formal_statement, informal_prefix):
    user_prompt = (
        "Think about and solve the following problem step by step in Lean 4.\n"
        "# Problem: Provide a formal proof for the following statement.\n"
        f"# Formal statement:\n```lean4\n{build_formal_block(formal_statement, informal_prefix)}\n```\n"
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt}
    ]
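# e.g. init_chat("theorem t : 1 = 1 := by rfl") returns the two-message seed:
# [{"role": "system", "content": SYSTEM_PROMPT},
#  {"role": "user", "content": "Think about and solve ... ```lean4 ... ```"}]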

# Gradio chat handler
@spaces.GPU
def chat_handler(user_message, informal_prefix, max_tokens, chat_history):
    # First turn: seed the history with the system prompt and formal statement;
    # follow-up turns just append the new user message
    if not chat_history or len(chat_history) < 2:
        chat_history = init_chat(user_message, informal_prefix)
    else:
        chat_history.append({"role": "user", "content": user_message})
    # Format prompt using chat template
    prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    attention_mask = torch.ones_like(input_ids)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=int(max_tokens),  # slider values arrive as floats
        pad_token_id=model.generation_config.pad_token_id,
        do_sample=True,
        temperature=model.generation_config.temperature,
        top_p=model.generation_config.top_p,
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Decode only the generated tokens for the new assistant message; slicing
    # `result` by len(prompt) misaligns once special tokens are skipped
    new_response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True).strip()
    chat_history.append({"role": "assistant", "content": new_response})
    # Rebuild (user, assistant) pairs, the format gr.Chatbot's tuple mode expects
    display_history = []
    for msg in chat_history[1:]:  # skip the system prompt
        if msg["role"] == "user":
            display_history.append([msg["content"], None])
        elif msg["role"] == "assistant" and display_history:
            display_history[-1][1] = msg["content"]
    # Extract Lean4 code
    code = extract_lean4_code(new_response)
    # Prepare output
    output_data = {
        "model_input": prompt,
        "model_output": result,
        "lean4_code": code,
        "chat_history": chat_history
    }
    return display_history, json.dumps(output_data, indent=2), code, chat_history

def main():
    with gr.Blocks() as demo:
        # Title and model description
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(joinus)
        with gr.Row():
            with gr.Column():
                chat = gr.Chatbot(label="Chat History")
                user_input = gr.Textbox(label="Your message or formal statement", lines=4)
                informal = gr.Textbox(value=additional_info_prompt, label="Optional informal prefix")
                max_tokens = gr.Slider(minimum=150, maximum=4096, value=2500, label="Max Tokens")
                submit = gr.Button("Send")
            with gr.Column():
                json_out = gr.JSON(label="Full Output")
                code_out = gr.Code(label="Extracted Lean4 Code", language=None)  # Gradio's Code component has no Lean highlighter
        state = gr.State([])
        # On submit, call chat_handler
        submit.click(chat_handler, [user_input, informal, max_tokens, state], [chat, json_out, code_out, state])
        gr.Markdown(citation)
    demo.launch()

if __name__ == "__main__":
    main()