# TheoremExplainAgent / src/core/code_generator.py
# Uploaded by dfdfdsfgs: "Upload project files" (commit d9486d1)
import glob
import hashlib
import json
import os
import re
from typing import Dict, List, Optional, Tuple, Union

from PIL import Image

from src.utils.utils import extract_json
from mllm_tools.utils import _prepare_text_inputs, _extract_code, _prepare_text_image_inputs
from mllm_tools.gemini import GeminiWrapper
from mllm_tools.vertex_ai import VertexAIWrapper
from task_generator import (
    get_prompt_code_generation,
    get_prompt_fix_error,
    get_prompt_visual_fix_error,
    get_banned_reasonings,
    get_prompt_rag_query_generation_fix_error,
    get_prompt_context_learning_code,
    get_prompt_rag_query_generation_code
)
from task_generator.prompts_raw import (
    _code_font_size,
    _code_disable,
    _code_limit,
    _prompt_manim_cheatsheet
)
from src.rag.vector_store import RAGVectorStore  # Import RAGVectorStore
class CodeGenerator:
    """A class for generating and managing Manim code."""

    def __init__(self, scene_model, helper_model, output_dir="output", print_response=False,
                 use_rag=False, use_context_learning=False,
                 context_learning_path="data/context_learning", chroma_db_path="rag/chroma_db",
                 manim_docs_path="rag/manim_docs", embedding_model="azure/text-embedding-3-large",
                 use_visual_fix_code=False, use_langfuse=True, session_id=None):
        """Initialize the CodeGenerator.

        Args:
            scene_model: The model used for scene generation (called as ``model(inputs, metadata=...)``).
            helper_model: The model used for helper tasks such as RAG query generation.
            output_dir (str, optional): Directory for output files. Defaults to "output".
            print_response (bool, optional): Whether to print model responses. Defaults to False.
            use_rag (bool, optional): Whether to use RAG. Defaults to False.
            use_context_learning (bool, optional): Whether to use context learning. Defaults to False.
            context_learning_path (str, optional): Path to context learning examples. Defaults to "data/context_learning".
            chroma_db_path (str, optional): Path to ChromaDB. Defaults to "rag/chroma_db".
            manim_docs_path (str, optional): Path to Manim docs. Defaults to "rag/manim_docs".
            embedding_model (str, optional): Name of embedding model. Defaults to "azure/text-embedding-3-large".
            use_visual_fix_code (bool, optional): Whether to use visual code fixing. Defaults to False.
            use_langfuse (bool, optional): Whether to use Langfuse logging (forwarded to the
                vector store). Defaults to True.
            session_id (str, optional): Session identifier. Defaults to None.
        """
        self.scene_model = scene_model
        self.helper_model = helper_model
        self.output_dir = output_dir
        self.print_response = print_response
        self.use_rag = use_rag
        self.use_context_learning = use_context_learning
        self.context_learning_path = context_learning_path
        # Must run after context_learning_path is set, since the loader reads it.
        self.context_examples = self._load_context_examples() if use_context_learning else None
        self.manim_docs_path = manim_docs_path
        self.use_visual_fix_code = use_visual_fix_code
        self.banned_reasonings = get_banned_reasonings()
        self.session_id = session_id  # Use session_id passed from VideoGenerator
        if use_rag:
            self.vector_store = RAGVectorStore(
                chroma_db_path=chroma_db_path,
                manim_docs_path=manim_docs_path,
                embedding_model=embedding_model,
                session_id=self.session_id,
                use_langfuse=use_langfuse
            )
        else:
            self.vector_store = None

    def _load_context_examples(self) -> Optional[str]:
        """Load all context learning examples from the specified directory.

        Every ``*.py`` file under ``self.context_learning_path`` (searched
        recursively) is read and prefixed with a ``# Example from <file>`` header.

        Returns:
            Optional[str]: Examples formatted by ``get_prompt_context_learning_code``,
            or None if no example files were found.
        """
        examples = []
        # Sorted so the generated prompt is stable across runs (glob order is arbitrary).
        example_files = sorted(glob.glob(f"{self.context_learning_path}/**/*.py", recursive=True))
        for example_file in example_files:
            with open(example_file, 'r', encoding='utf-8') as f:
                examples.append(f"# Example from {os.path.basename(example_file)}\n{f.read()}\n")
        if not examples:
            return None
        return get_prompt_context_learning_code(examples="\n".join(examples))

    def _rag_cache_dir(self, topic: str, scene_number: int) -> str:
        """Return the RAG-query cache directory for one scene, creating it if needed.

        The directory is ``<output_dir>/<sanitized topic>/scene<N>/rag_cache``, where the
        topic is lower-cased and every run of characters outside ``[a-z0-9_]`` becomes ``_``.
        ``topic`` must be a string (it is lower-cased here).
        """
        topic_dir = re.sub(r'[^a-z0-9_]+', '_', topic.lower())
        cache_dir = os.path.join(self.output_dir, topic_dir, f"scene{scene_number}", "rag_cache")
        os.makedirs(cache_dir, exist_ok=True)
        return cache_dir

    def _error_fix_cache_file(self, error: str, code: str, topic: str, scene_number: int) -> str:
        """Return the cache file path for error-fix RAG queries of one (error, code) pair.

        The file name embeds a digest of the error and the code, so a new error in
        the same scene gets fresh queries instead of reusing those of an earlier error.
        """
        digest = hashlib.sha256(f"{error}\0{code}".encode("utf-8")).hexdigest()[:16]
        return os.path.join(self._rag_cache_dir(topic, scene_number), f"rag_queries_error_fix_{digest}.json")

    @staticmethod
    def _parse_json_response(response_text: str):
        """Parse the JSON payload of a model response.

        If the text contains a ```json fenced block, only the first such block is
        parsed. Otherwise any stray code fences are removed and the whole text is parsed.

        Raises:
            json.JSONDecodeError: If the extracted text is not valid JSON.
        """
        # Non-greedy so trailing fenced blocks after the JSON are not swallowed.
        match = re.search(r'```json(.*?)```', response_text, re.DOTALL)
        if match:
            payload = match.group(1)
        else:
            payload = response_text.replace("```json", "").replace("```", "")
        return json.loads(payload)

    def _generate_rag_queries_code(self, implementation: str, scene_trace_id: str = None, topic: str = None,
                                   scene_number: int = None, session_id: str = None,
                                   relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries from the implementation plan.

        Queries are cached per scene in ``rag_queries_code.json``. When that file
        exists, its contents are returned without calling the model.

        Args:
            implementation (str): The implementation plan text
            scene_trace_id (str, optional): Trace ID for the scene. Defaults to None.
            topic (str, optional): Topic of the scene; required for the cache path. Defaults to None.
            scene_number (int, optional): Scene number. Defaults to None.
            session_id (str, optional): Session identifier. Defaults to None.
            relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None (no plugins).

        Returns:
            List[str]: List of generated RAG queries (empty if the model reply is not valid JSON)
        """
        cache_key = f"{topic}_scene{scene_number}"
        cache_file = os.path.join(self._rag_cache_dir(topic, scene_number), "rag_queries_code.json")

        if os.path.exists(cache_file):
            with open(cache_file, 'r', encoding='utf-8') as f:
                cached_queries = json.load(f)
            print(f"Using cached RAG queries for {cache_key}")
            return cached_queries

        plugins_text = ", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        prompt = get_prompt_rag_query_generation_code(implementation, plugins_text)
        response_text = self.helper_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": "rag_query_generation", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )
        print(f"RAG queries: {response_text}")

        try:
            queries = self._parse_json_response(response_text)
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError when parsing RAG queries for code generation: {e}")
            print(f"Response text was: {response_text}")
            return []  # Return empty list in case of parsing error

        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(queries, f)
        return queries

    def _generate_rag_queries_error_fix(self, error: str, code: str, scene_trace_id: str = None, topic: str = None,
                                        scene_number: int = None, session_id: str = None,
                                        relevant_plugins: Optional[List[str]] = None) -> List[str]:
        """Generate RAG queries for fixing code errors.

        Queries are cached per (scene, error, code) combination, so each distinct
        error gets its own queries. Cache files written by older versions under the
        single name ``rag_queries_error_fix.json`` are no longer read.

        Args:
            error (str): The error message to fix
            code (str): The code containing the error
            scene_trace_id (str, optional): Trace ID for the scene. Defaults to None.
            topic (str, optional): Topic of the scene; required for the cache path. Defaults to None.
            scene_number (int, optional): Scene number. Defaults to None.
            session_id (str, optional): Session identifier. Defaults to None.
            relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None (no plugins).

        Returns:
            List[str]: List of generated RAG queries for error fixing (empty if the
            model reply is not valid JSON)
        """
        cache_key = f"{topic}_scene{scene_number}_error_fix"
        cache_file = self._error_fix_cache_file(error, code, topic, scene_number)

        if os.path.exists(cache_file):
            with open(cache_file, 'r', encoding='utf-8') as f:
                cached_queries = json.load(f)
            print(f"Using cached RAG queries for error fix in {cache_key}")
            return cached_queries

        prompt = get_prompt_rag_query_generation_fix_error(
            error=error,
            code=code,
            relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant."
        )
        response_text = self.helper_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": "rag-query-generation-fix-error", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )

        try:
            queries = self._parse_json_response(response_text)
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError when parsing RAG queries for error fix: {e}")
            print(f"Response text was: {response_text}")
            return []  # Return empty list in case of parsing error

        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(queries, f)
        return queries

    def _extract_code_with_retries(self, response_text: str, pattern: str, generation_name: str = None,
                                   trace_id: str = None, session_id: str = None, max_retries: int = 10) -> str:
        """Extract code from response text with retry logic.

        On each failed match (except the last), ``self.scene_model`` is asked to
        re-emit the previous response in the requested format, and the new reply is searched.

        Args:
            response_text (str): The text containing code to extract
            pattern (str): Regex pattern for extracting code; group 1 is returned
            generation_name (str, optional): Name of generation step. Defaults to None.
            trace_id (str, optional): Trace identifier. Defaults to None.
            session_id (str, optional): Session identifier. Defaults to None.
            max_retries (int, optional): Maximum number of attempts. Defaults to 10.

        Returns:
            str: The extracted code

        Raises:
            ValueError: If code extraction fails after max retries
        """
        retry_prompt = """
Please extract the Python code in the correct format using the pattern: {pattern}.
You MUST NOT include any other text or comments.
You MUST return the exact same code as in the previous response, NO CONTENT EDITING is allowed.
Previous response:
{response_text}
"""
        for attempt in range(max_retries):
            code_match = re.search(pattern, response_text, re.DOTALL)
            if code_match:
                return code_match.group(1)
            if attempt < max_retries - 1:
                print(f"Attempt {attempt + 1}: Failed to extract code pattern. Retrying...")
                # Regenerate response with a more explicit prompt
                response_text = self.scene_model(
                    _prepare_text_inputs(retry_prompt.format(pattern=pattern, response_text=response_text)),
                    metadata={
                        "generation_name": f"{generation_name}_format_retry_{attempt + 1}",
                        "trace_id": trace_id,
                        "session_id": session_id
                    }
                )
        raise ValueError(f"Failed to extract code pattern after {max_retries} attempts. Pattern: {pattern}")

    def generate_manim_code(self,
                            topic: str,
                            description: str,
                            scene_outline: str,
                            scene_implementation: str,
                            scene_number: int,
                            additional_context: Union[str, List[str]] = None,
                            scene_trace_id: str = None,
                            session_id: str = None,
                            rag_queries_cache: Dict = None) -> Tuple[str, str]:
        """Generate Manim code from video plan.

        Args:
            topic (str): Topic of the scene
            description (str): Description of the scene
            scene_outline (str): Outline of the scene
            scene_implementation (str): Implementation details
            scene_number (int): Scene number
            additional_context (Union[str, List[str]], optional): Additional context. It is never
                mutated; a new list is built when extra context has to be appended. Defaults to None.
            scene_trace_id (str, optional): Trace identifier. Defaults to None.
            session_id (str, optional): Session identifier. Defaults to None.
            rag_queries_cache (Dict, optional): Currently unused; kept for interface compatibility.

        Returns:
            Tuple[str, str]: Generated code and response text
        """
        if self.use_context_learning or self.use_rag:
            # Normalize to a fresh list: appending to a str would crash, and appending
            # to the caller's list would accumulate context across repeated calls.
            if additional_context is None:
                additional_context = []
            elif isinstance(additional_context, str):
                additional_context = [additional_context]
            else:
                additional_context = list(additional_context)

        if self.use_context_learning and self.context_examples:
            additional_context.append(self.context_examples)

        if self.use_rag:
            # Generate RAG queries (will use cache if available)
            rag_queries = self._generate_rag_queries_code(
                implementation=scene_implementation,
                scene_trace_id=scene_trace_id,
                topic=topic,
                scene_number=scene_number,
                session_id=session_id
            )
            retrieved_docs = self.vector_store.find_relevant_docs(
                queries=rag_queries,
                k=2,  # number of documents to retrieve
                trace_id=scene_trace_id,
                topic=topic,
                scene_number=scene_number
            )
            additional_context.append(retrieved_docs)

        prompt = get_prompt_code_generation(
            scene_outline=scene_outline,
            scene_implementation=scene_implementation,
            topic=topic,
            description=description,
            scene_number=scene_number,
            additional_context=additional_context
        )
        response_text = self.scene_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": "code_generation", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )
        code = self._extract_code_with_retries(
            response_text,
            r"```python(.*)```",
            generation_name="code_generation",
            trace_id=scene_trace_id,
            session_id=session_id
        )
        return code, response_text

    def fix_code_errors(self, implementation_plan: str, code: str, error: str, scene_trace_id: str, topic: str,
                        scene_number: int, session_id: str, rag_queries_cache: Dict = None) -> Tuple[str, str]:
        """Fix errors in generated Manim code.

        Args:
            implementation_plan (str): Original implementation plan
            code (str): Code containing errors
            error (str): Error message to fix
            scene_trace_id (str): Trace identifier
            topic (str): Topic of the scene
            scene_number (int): Scene number
            session_id (str): Session identifier
            rag_queries_cache (Dict, optional): Currently unused; kept for interface compatibility.

        Returns:
            Tuple[str, str]: Fixed code and response text
        """
        if self.use_rag:
            rag_queries = self._generate_rag_queries_error_fix(
                error=error,
                code=code,
                scene_trace_id=scene_trace_id,
                topic=topic,
                scene_number=scene_number,
                session_id=session_id
            )
            retrieved_docs = self.vector_store.find_relevant_docs(
                queries=rag_queries,
                k=2,  # number of documents to retrieve for error fixing
                trace_id=scene_trace_id,
                topic=topic,
                scene_number=scene_number
            )
            prompt = get_prompt_fix_error(implementation_plan=implementation_plan, manim_code=code, error=error, additional_context=retrieved_docs)
        else:
            prompt = get_prompt_fix_error(implementation_plan=implementation_plan, manim_code=code, error=error)

        response_text = self.scene_model(
            _prepare_text_inputs(prompt),
            metadata={"generation_name": "code_fix_error", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id}
        )
        fixed_code = self._extract_code_with_retries(
            response_text,
            r"```python(.*)```",
            generation_name="code_fix_error",
            trace_id=scene_trace_id,
            session_id=session_id
        )
        return fixed_code, response_text

    def visual_self_reflection(self, code: str, media_path: "Union[str, Image.Image]", scene_trace_id: str,
                               topic: str, scene_number: int, session_id: str) -> Tuple[str, str]:
        """Use snapshot image or mp4 video to fix code.

        A ``.mp4`` path is sent as video only when ``scene_model`` is a Gemini or
        Vertex AI wrapper. Any other path is opened with PIL and sent as an image.

        Args:
            code (str): Code to fix
            media_path (Union[str, Image.Image]): Path to media file or PIL Image
            scene_trace_id (str): Trace identifier
            topic (str): Topic of the scene
            scene_number (int): Scene number
            session_id (str): Session identifier

        Returns:
            Tuple[str, str]: Fixed code and response text
        """
        is_video = isinstance(media_path, str) and media_path.lower().endswith('.mp4')

        with open('task_generator/prompts_raw/prompt_visual_self_reflection.txt', 'r', encoding='utf-8') as f:
            prompt_template = f.read()
        prompt = prompt_template.format(code=code)

        if is_video and isinstance(self.scene_model, (GeminiWrapper, VertexAIWrapper)):
            messages = [
                {"type": "text", "content": prompt},
                {"type": "video", "content": media_path}
            ]
        else:
            # For images or non-Gemini models
            media = Image.open(media_path) if isinstance(media_path, str) else media_path
            messages = [
                {"type": "text", "content": prompt},
                {"type": "image", "content": media}
            ]

        response_text = self.scene_model(
            messages,
            metadata={
                "generation_name": "visual_self_reflection",
                "trace_id": scene_trace_id,
                "tags": [topic, f"scene{scene_number}"],
                "session_id": session_id
            }
        )
        fixed_code = self._extract_code_with_retries(
            response_text,
            r"```python(.*)```",
            generation_name="visual_self_reflection",
            trace_id=scene_trace_id,
            session_id=session_id
        )
        return fixed_code, response_text