# Hugging Face Space: Running on Zero
import spaces | |
import re | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig | |
import torch | |
import json | |
# Preamble placed in front of every formal statement sent to the model:
# imports Mathlib and Aesop, sets `maxHeartbeats` to 0, and opens the
# namespaces listed on the last line.
LEAN4_DEFAULT_HEADER = (
    "import Mathlib\n"
    "import Aesop\n\n"
    "set_option maxHeartbeats 0\n\n"
    "open BigOperators Real Nat Topology Rat\n"
)
# Markdown heading text for the demo page.
# NOTE(review): the emoji in this literal appear mis-encoded (mojibake); the
# intended characters cannot be recovered from this file, so the text is kept
# as found -- restore it from the upstream source if available.
title = "# ๐๐ปโโ๏ธWelcome to ๐Tonic's ๐๐๐จ๐ปโ๐ฌMoonshot Math"
# Markdown blurb rendered under the heading: a short model summary followed by
# a bullet list of links.
# The lines previously ended in a backslash, which inside a triple-quoted
# string swallows the newline; that glued the sentences together
# ("works.Trained ...") and put every "- [..]" item on one line so the list
# never rendered.  Real newlines (and a blank line before the list) fix that.
description = """
**Kimina-Prover-72B** is a state-of-the-art large formal reasoning model for Lean 4, achieving **80%+ pass rate** on the miniF2F benchmark, outperforming all prior works.
Trained with Reinforcement Learning, 72B parameters, and a 32K token context window.

- [Kimina-Prover-Preview GitHub](https://github.com/MoonshotAI/Kimina-Prover-Preview)
- [Hugging Face: AI-MO/Kimina-Prover-72B](https://huggingface.co/AI-MO/Kimina-Prover-72B)
- [Kimina Prover blog](https://huggingface.co/blog/AI-MO/kimina-prover)
- [unimath dataset](https://huggingface.co/datasets/introspector/unimath)
"""
# Markdown block-quote holding the BibTeX entry shown at the bottom of the page.
citation = """> **Citation:**
> ```
> @article{kimina_prover_2025,
> title = {Kimina-Prover Preview: Towards Large Formal Reasoning Models with Reinforcement Learning},
> author = {Wang, Haiming and Unsal, Mert and ...},
> year = {2025},
> url = {http://arxiv.org/abs/2504.11354},
> }
> ```
"""
# Markdown "community" footer rendered below the description.
# The Discord link previously had empty link text (`[](...)`), which renders as
# nothing clickable; it now carries a visible label.
# NOTE(review): the emoji in this literal appear mis-encoded (mojibake) and are
# kept as found -- restore them from the upstream source if available.
joinus = """
### Join us:
๐TeamTonic๐ is always making cool demos! Join our active builder's ๐ ๏ธcommunity ๐ป
[Join us on Discord](https://discord.gg/qdfnvSPcqP)
On ๐คHuggingface: [MultiTransformer](https://huggingface.co/MultiTransformer)
On ๐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to๐ [Build Tonic](https://git.tonic-ai.com/contribute)
๐คBig thanks to Yuvi Sharma and all the folks at huggingface for the community grant ๐ค
"""
# System message that opens every conversation sent to the model.
SYSTEM_PROMPT = "You are an expert in mathematics and Lean 4."
# Helper to build a Lean4 code block
def build_formal_block(formal_statement: str, informal_prefix: str = "") -> str:
    """Return the Lean 4 source text for one problem.

    The result is three parts, each separated by a newline: the shared
    ``LEAN4_DEFAULT_HEADER``, the ``informal_prefix`` (typically a Lean
    comment; may be empty, which leaves a blank line), and the
    ``formal_statement`` itself.  No trailing newline is added.

    Args:
        formal_statement: The statement to be proved, as Lean 4 source.
        informal_prefix: Optional text inserted between header and statement.

    Returns:
        The concatenated source, without any Markdown fence around it.
    """
    return (
        f"{LEAN4_DEFAULT_HEADER}\n"
        f"{informal_prefix}\n"
        f"{formal_statement}"
    )
# Helper to extract the first Lean4 code block from text
# Compiled once: everything after the first "```lean4" up to the next "```",
# or up to the end of the text when the fence is never closed.
_LEAN4_FENCE = re.compile(r"```lean4(.*?)(```|$)", re.DOTALL)


def extract_lean4_code(text: str) -> str:
    """Return the contents of the first ```` ```lean4 ```` fence in ``text``.

    Blank and whitespace-only lines inside the fence are dropped; the
    remaining lines keep their own indentation.  When ``text`` contains no
    ``lean4`` fence at all, the whole text is returned with surrounding
    whitespace stripped.

    Args:
        text: Model output that may contain a fenced Lean 4 snippet.

    Returns:
        The extracted code, or the stripped input when no fence is present.
    """
    fenced = _LEAN4_FENCE.search(text)
    if fenced is None:
        return text.strip()
    return "\n".join(
        line for line in fenced.group(1).split("\n") if line.strip()
    )
# Example problems
# Each `unimathN` string is a goal dump: hypotheses, a line of "=" characters,
# then the goal itself.
# NOTE(review): several symbols in these literals appear mis-encoded
# (mojibake); they are kept exactly as found because the intended characters
# cannot be determined from this file -- restore from the upstream dataset.
unimath1 = """Goal:
X : UU
Y : UU
P : UU
xp : (X โ P) โ P
yp : (Y โ P) โ P
X0 : X ร Y โ P
x : X
============================
(Y โ P)"""
unimath2 = """Goal:
R : ring M : module R
============================
(islinear (idfun M))"""
unimath3 = """Goal:
X : UU i : nat b : hProptoType (i < S i) x : Vector X (S i) r : i = i
============================
(pr1 lastelement = pr1 (i,, b))"""
unimath4 = """Goal:
X : dcpo CX : continuous_dcpo_struct X x : pr1hSet X y : pr1hSet X
============================
(x โ y โ (โ i : approximating_family CX x, approximating_family CX x i โ y))"""

# Default value for the "informal prefix" field: a Lean block comment.
additional_info_prompt = "/-Explain using mathematics-/\n"

# One row per example: [statement, informal prefix, max tokens].
examples = [
    [unimath1, additional_info_prompt, 2500],
    [unimath2, additional_info_prompt, 2500],
    [unimath3, additional_info_prompt, 2500],
    [unimath4, additional_info_prompt, 2500],
]
# Load tokenizer and model once at import time.
# NOTE: trust_remote_code=True lets the model repository run its own Python
# code during loading.
model_name = "AI-MO/Kimina-Prover-72B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Set generation config
model.generation_config = GenerationConfig.from_pretrained(model_name)
# Reuse the end-of-sequence id for padding.  `eos_token_id` may be a list of
# ids while `pad_token_id` has to be a single id, so take the first entry.
_eos_token_id = model.generation_config.eos_token_id
if isinstance(_eos_token_id, (list, tuple)):
    _eos_token_id = _eos_token_id[0] if _eos_token_id else None
model.generation_config.pad_token_id = _eos_token_id
model.generation_config.do_sample = True
model.generation_config.temperature = 0.6
model.generation_config.top_p = 0.95
# Initialize chat history with system prompt
def init_chat(formal_statement: str, informal_prefix: str) -> list:
    """Create the opening two messages of a conversation.

    The user message asks for a step-by-step Lean 4 proof and embeds the
    output of ``build_formal_block(formal_statement, informal_prefix)`` inside
    a ```` ```lean4 ```` fence.

    Args:
        formal_statement: The statement to prove, as Lean 4 source.
        informal_prefix: Text placed between the Lean header and the statement.

    Returns:
        ``[system_message, user_message]`` where each message is a dict with
        ``"role"`` and ``"content"`` keys.
    """
    user_prompt = (
        "Think about and solve the following problem step by step in Lean 4.\n"
        "# Problem: Provide a formal proof for the following statement.\n"
        f"# Formal statement:\n```lean4\n{build_formal_block(formal_statement, informal_prefix)}\n```\n"
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
# Gradio chat handler
def _history_to_pairs(chat_history):
    """Fold role/content messages into ``[user, assistant]`` pairs.

    System messages are skipped.  A user message with no reply yet yields
    ``[text, None]``; an assistant message with no preceding open user turn
    yields ``[None, text]``.
    """
    pairs = []
    for msg in chat_history:
        if msg["role"] == "user":
            pairs.append([msg["content"], None])
        elif msg["role"] == "assistant":
            if pairs and pairs[-1][1] is None:
                pairs[-1][1] = msg["content"]
            else:
                pairs.append([None, msg["content"]])
    return pairs


def chat_handler(user_message, informal_prefix, max_tokens, chat_history):
    """Run one chat turn: build the prompt, generate, and package the outputs.

    Args:
        user_message: Text from the input box.  On the first turn it is used
            as the formal statement; on later turns it is appended verbatim.
        informal_prefix: Prefix passed to ``init_chat`` on the first turn only.
        max_tokens: Maximum number of tokens to generate for this reply.
        chat_history: List of ``{"role", "content"}`` dicts carried between
            calls; empty (or shorter than two messages) on the first turn.

    Returns:
        A 4-tuple ``(display_history, json_text, code, chat_history)``:
        ``display_history`` is a list of ``[user, assistant]`` pairs for the
        chatbot widget, ``json_text`` is a JSON dump of the prompt, the full
        decoded output, the extracted code and the history, ``code`` is the
        result of ``extract_lean4_code`` on the reply, and ``chat_history``
        is the updated message list.
    """
    # If chat_history is empty, initialize with system and first user message
    if not chat_history or len(chat_history) < 2:
        chat_history = init_chat(user_message, informal_prefix)
        # Show the raw statement rather than the expanded prompt.
        display_history = [[user_message, None]]
    else:
        # Append new user message
        chat_history.append({"role": "user", "content": user_message})
        display_history = _history_to_pairs(chat_history)

    # Format prompt using chat template
    prompt = tokenizer.apply_chat_template(
        chat_history, tokenize=False, add_generation_prompt=True
    )
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[1]

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # Bound only the newly generated tokens; the slider value may arrive
        # as a float, so coerce it.
        max_new_tokens=int(max_tokens),
        pad_token_id=model.generation_config.pad_token_id,
        temperature=model.generation_config.temperature,
        top_p=model.generation_config.top_p,
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the new assistant message by decoding the tokens that
    # follow the prompt.  Slicing `result` by len(prompt) is wrong because
    # `result` has its special tokens stripped while `prompt` still has them.
    new_response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    # Add assistant message to chat history and fill in the open display pair
    chat_history.append({"role": "assistant", "content": new_response})
    display_history[-1][1] = new_response

    # Extract Lean4 code
    code = extract_lean4_code(new_response)

    # Prepare output
    output_data = {
        "model_input": prompt,
        "model_output": result,
        "lean4_code": code,
        "chat_history": chat_history,
    }
    return display_history, json.dumps(output_data, indent=2), code, chat_history
def main():
    """Build the Gradio interface, wire the Send button, and launch the app."""
    with gr.Blocks() as demo:
        # Title and Model Description
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(joinus)
        with gr.Row():
            with gr.Column():
                chat = gr.Chatbot(label="Chat History")
                user_input = gr.Textbox(label="Your message or formal statement", lines=4)
                informal = gr.Textbox(value=additional_info_prompt, label="Optional informal prefix")
                max_tokens = gr.Slider(minimum=150, maximum=4096, value=2500, label="Max Tokens")
                submit = gr.Button("Send")
            with gr.Column():
                json_out = gr.JSON(label="Full Output")
                # "lean4" is not among the languages gr.Code accepts, so show
                # the snippet as plain text instead of failing at construction.
                code_out = gr.Code(label="Extracted Lean4 Code", language=None)
        # Per-session message list passed into and out of chat_handler.
        state = gr.State([])
        # On submit, call chat_handler
        submit.click(
            chat_handler,
            [user_input, informal, max_tokens, state],
            [chat, json_out, code_out, state],
        )
        gr.Markdown(citation)
    demo.launch()


if __name__ == "__main__":
    main()