Spaces:
Runtime error
Runtime error
import os | |
import re | |
import json | |
from typing import List, Dict | |
from mllm_tools.utils import _prepare_text_inputs | |
from task_generator import ( | |
get_prompt_rag_query_generation_fix_error, | |
get_prompt_detect_plugins, | |
get_prompt_rag_query_generation_technical, | |
get_prompt_rag_query_generation_vision_storyboard, | |
get_prompt_rag_query_generation_narration, | |
get_prompt_rag_query_generation_code | |
) | |
from src.rag.vector_store import RAGVectorStore | |
class RAGIntegration: | |
"""Class for integrating RAG (Retrieval Augmented Generation) functionality. | |
This class handles RAG integration including plugin detection, query generation, | |
and document retrieval. | |
Args: | |
helper_model: Model used for generating queries and processing text | |
output_dir (str): Directory for output files | |
chroma_db_path (str): Path to ChromaDB | |
manim_docs_path (str): Path to Manim documentation | |
embedding_model (str): Name of embedding model to use | |
use_langfuse (bool, optional): Whether to use Langfuse logging. Defaults to True | |
session_id (str, optional): Session identifier. Defaults to None | |
""" | |
def __init__(self, helper_model, output_dir, chroma_db_path, manim_docs_path, embedding_model, use_langfuse=True, session_id=None): | |
self.helper_model = helper_model | |
self.output_dir = output_dir | |
self.manim_docs_path = manim_docs_path | |
self.session_id = session_id | |
self.relevant_plugins = None | |
self.vector_store = RAGVectorStore( | |
chroma_db_path=chroma_db_path, | |
manim_docs_path=manim_docs_path, | |
embedding_model=embedding_model, | |
session_id=self.session_id, | |
use_langfuse=use_langfuse, | |
helper_model=helper_model | |
) | |
def set_relevant_plugins(self, plugins: List[str]) -> None: | |
"""Set the relevant plugins for the current video. | |
Args: | |
plugins (List[str]): List of plugin names to set as relevant | |
""" | |
self.relevant_plugins = plugins | |
def detect_relevant_plugins(self, topic: str, description: str) -> List[str]: | |
"""Detect which plugins might be relevant based on topic and description. | |
Args: | |
topic (str): Topic of the video | |
description (str): Description of the video content | |
Returns: | |
List[str]: List of detected relevant plugin names | |
""" | |
# Load plugin descriptions | |
plugins = self._load_plugin_descriptions() | |
if not plugins: | |
return [] | |
# Get formatted prompt using the task_generator function | |
prompt = get_prompt_detect_plugins( | |
topic=topic, | |
description=description, | |
plugin_descriptions=json.dumps([{'name': p['name'], 'description': p['description']} for p in plugins], indent=2) | |
) | |
try: | |
response = self.helper_model( | |
_prepare_text_inputs(prompt), | |
metadata={"generation_name": "detect-relevant-plugins", "tags": [topic, "plugin-detection"], "session_id": self.session_id} | |
) | |
# Clean the response to ensure it only contains the JSON array | |
response = re.search(r'```json(.*)```', response, re.DOTALL).group(1) | |
try: | |
relevant_plugins = json.loads(response) | |
except json.JSONDecodeError as e: | |
print(f"JSONDecodeError when parsing relevant plugins: {e}") | |
print(f"Response text was: {response}") | |
return [] | |
print(f"LLM detected relevant plugins: {relevant_plugins}") | |
return relevant_plugins | |
except Exception as e: | |
print(f"Error detecting plugins with LLM: {e}") | |
return [] | |
def _load_plugin_descriptions(self) -> list: | |
"""Load plugin descriptions from JSON file. | |
Returns: | |
list: List of plugin descriptions, empty list if loading fails | |
""" | |
try: | |
plugin_config_path = os.path.join( | |
self.manim_docs_path, | |
"plugin_docs", | |
"plugins.json" | |
) | |
if os.path.exists(plugin_config_path): | |
with open(plugin_config_path, "r") as f: | |
return json.load(f) | |
else: | |
print(f"Plugin descriptions file not found at {plugin_config_path}") | |
return [] | |
except Exception as e: | |
print(f"Error loading plugin descriptions: {e}") | |
return [] | |
def _generate_rag_queries_storyboard(self, scene_plan: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = []) -> List[str]: | |
"""Generate RAG queries from the scene plan to help create storyboard. | |
Args: | |
scene_plan (str): Scene plan text to generate queries from | |
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None | |
topic (str, optional): Topic name. Defaults to None | |
scene_number (int, optional): Scene number. Defaults to None | |
session_id (str, optional): Session identifier. Defaults to None | |
relevant_plugins (List[str], optional): List of relevant plugins. Defaults to empty list | |
Returns: | |
List[str]: List of generated RAG queries | |
""" | |
cache_key = f"{topic}_scene{scene_number}_storyboard_rag" | |
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache") | |
os.makedirs(cache_dir, exist_ok=True) | |
cache_file = os.path.join(cache_dir, "rag_queries_storyboard.json") | |
if os.path.exists(cache_file): | |
with open(cache_file, 'r') as f: | |
return json.load(f) | |
# Format relevant plugins as a string | |
plugins_str = ", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant." | |
# Generate the prompt with only the required arguments | |
prompt = get_prompt_rag_query_generation_vision_storyboard( | |
scene_plan=scene_plan, | |
relevant_plugins=plugins_str | |
) | |
queries = self.helper_model( | |
_prepare_text_inputs(prompt), | |
metadata={"generation_name": "rag_query_generation_storyboard", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id} | |
) | |
# retreive json triple backticks | |
try: # add try-except block to handle potential json decode errors | |
queries = re.search(r'```json(.*)```', queries, re.DOTALL).group(1) | |
queries = json.loads(queries) | |
except json.JSONDecodeError as e: | |
print(f"JSONDecodeError when parsing RAG queries for storyboard: {e}") | |
print(f"Response text was: {queries}") | |
return [] # Return empty list in case of parsing error | |
# Cache the queries | |
with open(cache_file, 'w') as f: | |
json.dump(queries, f) | |
return queries | |
def _generate_rag_queries_technical(self, storyboard: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = []) -> List[str]: | |
"""Generate RAG queries from the storyboard to help create technical implementation. | |
Args: | |
storyboard (str): Storyboard text to generate queries from | |
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None | |
topic (str, optional): Topic name. Defaults to None | |
scene_number (int, optional): Scene number. Defaults to None | |
session_id (str, optional): Session identifier. Defaults to None | |
relevant_plugins (List[str], optional): List of relevant plugins. Defaults to empty list | |
Returns: | |
List[str]: List of generated RAG queries | |
""" | |
cache_key = f"{topic}_scene{scene_number}_technical_rag" | |
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache") | |
os.makedirs(cache_dir, exist_ok=True) | |
cache_file = os.path.join(cache_dir, "rag_queries_technical.json") | |
if os.path.exists(cache_file): | |
with open(cache_file, 'r') as f: | |
return json.load(f) | |
prompt = get_prompt_rag_query_generation_technical( | |
storyboard=storyboard, | |
relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant." | |
) | |
queries = self.helper_model( | |
_prepare_text_inputs(prompt), | |
metadata={"generation_name": "rag_query_generation_technical", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id} | |
) | |
try: # add try-except block to handle potential json decode errors | |
queries = re.search(r'```json(.*)```', queries, re.DOTALL).group(1) | |
queries = json.loads(queries) | |
except json.JSONDecodeError as e: | |
print(f"JSONDecodeError when parsing RAG queries for technical implementation: {e}") | |
print(f"Response text was: {queries}") | |
return [] # Return empty list in case of parsing error | |
# Cache the queries | |
with open(cache_file, 'w') as f: | |
json.dump(queries, f) | |
return queries | |
def _generate_rag_queries_narration(self, storyboard: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None, relevant_plugins: List[str] = []) -> List[str]: | |
"""Generate RAG queries from the storyboard to help create narration plan. | |
Args: | |
storyboard (str): Storyboard text to generate queries from | |
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None | |
topic (str, optional): Topic name. Defaults to None | |
scene_number (int, optional): Scene number. Defaults to None | |
session_id (str, optional): Session identifier. Defaults to None | |
relevant_plugins (List[str], optional): List of relevant plugins. Defaults to empty list | |
Returns: | |
List[str]: List of generated RAG queries | |
""" | |
cache_key = f"{topic}_scene{scene_number}_narration_rag" | |
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache") | |
os.makedirs(cache_dir, exist_ok=True) | |
cache_file = os.path.join(cache_dir, "rag_queries_narration.json") | |
if os.path.exists(cache_file): | |
with open(cache_file, 'r') as f: | |
return json.load(f) | |
prompt = get_prompt_rag_query_generation_narration( | |
storyboard=storyboard, | |
relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant." | |
) | |
queries = self.helper_model( | |
_prepare_text_inputs(prompt), | |
metadata={"generation_name": "rag_query_generation_narration", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id} | |
) | |
try: # add try-except block to handle potential json decode errors | |
queries = re.search(r'```json(.*)```', queries, re.DOTALL).group(1) | |
queries = json.loads(queries) | |
except json.JSONDecodeError as e: | |
print(f"JSONDecodeError when parsing narration RAG queries: {e}") | |
print(f"Response text was: {queries}") | |
return [] # Return empty list in case of parsing error | |
# Cache the queries | |
with open(cache_file, 'w') as f: | |
json.dump(queries, f) | |
return queries | |
def get_relevant_docs(self, rag_queries: List[Dict], scene_trace_id: str, topic: str, scene_number: int) -> List[str]: | |
"""Get relevant documentation using the vector store. | |
Args: | |
rag_queries (List[Dict]): List of RAG queries to search for | |
scene_trace_id (str): Trace identifier for the scene | |
topic (str): Topic name | |
scene_number (int): Scene number | |
Returns: | |
List[str]: List of relevant documentation snippets | |
""" | |
return self.vector_store.find_relevant_docs( | |
queries=rag_queries, | |
k=2, | |
trace_id=scene_trace_id, | |
topic=topic, | |
scene_number=scene_number | |
) | |
def _generate_rag_queries_code(self, implementation_plan: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, relevant_plugins: List[str] = None) -> List[str]: | |
"""Generate RAG queries from implementation plan. | |
Args: | |
implementation_plan (str): Implementation plan text to generate queries from | |
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None | |
topic (str, optional): Topic name. Defaults to None | |
scene_number (int, optional): Scene number. Defaults to None | |
relevant_plugins (List[str], optional): List of relevant plugins. Defaults to None | |
Returns: | |
List[str]: List of generated RAG queries | |
""" | |
cache_key = f"{topic}_scene{scene_number}" | |
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache") | |
os.makedirs(cache_dir, exist_ok=True) | |
cache_file = os.path.join(cache_dir, "rag_queries_code.json") | |
if os.path.exists(cache_file): | |
with open(cache_file, 'r') as f: | |
return json.load(f) | |
prompt = get_prompt_rag_query_generation_code( | |
implementation_plan=implementation_plan, | |
relevant_plugins=", ".join(relevant_plugins) if relevant_plugins else "No plugins are relevant." | |
) | |
try: | |
response = self.helper_model( | |
_prepare_text_inputs(prompt), | |
metadata={"generation_name": "rag_query_generation_code", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": self.session_id} | |
) | |
# Clean and parse response | |
response = re.search(r'```json(.*)```', response, re.DOTALL).group(1) | |
queries = json.loads(response) | |
# Cache the queries | |
with open(cache_file, 'w') as f: | |
json.dump(queries, f) | |
return queries | |
except Exception as e: | |
print(f"Error generating RAG queries: {e}") | |
return [] | |
def _generate_rag_queries_error_fix(self, error: str, code: str, scene_trace_id: str = None, topic: str = None, scene_number: int = None, session_id: str = None) -> List[str]: | |
"""Generate RAG queries for fixing code errors. | |
Args: | |
error (str): Error message to generate queries from | |
code (str): Code containing the error | |
scene_trace_id (str, optional): Trace identifier for the scene. Defaults to None | |
topic (str, optional): Topic name. Defaults to None | |
scene_number (int, optional): Scene number. Defaults to None | |
session_id (str, optional): Session identifier. Defaults to None | |
Returns: | |
List[str]: List of generated RAG queries | |
""" | |
if self.relevant_plugins is None: | |
print("Warning: No plugins have been detected yet") | |
plugins_str = "No plugins are relevant." | |
else: | |
plugins_str = ", ".join(self.relevant_plugins) if self.relevant_plugins else "No plugins are relevant." | |
cache_key = f"{topic}_scene{scene_number}_error_fix" | |
cache_dir = os.path.join(self.output_dir, re.sub(r'[^a-z0-9_]+', '_', topic.lower()), f"scene{scene_number}", "rag_cache") | |
os.makedirs(cache_dir, exist_ok=True) | |
cache_file = os.path.join(cache_dir, "rag_queries_error_fix.json") | |
if os.path.exists(cache_file): | |
with open(cache_file, 'r') as f: | |
cached_queries = json.load(f) | |
print(f"Using cached RAG queries for error fix in {cache_key}") | |
return cached_queries | |
prompt = get_prompt_rag_query_generation_fix_error( | |
error=error, | |
code=code, | |
relevant_plugins=plugins_str | |
) | |
queries = self.helper_model( | |
_prepare_text_inputs(prompt), | |
metadata={"generation_name": "rag-query-generation-fix-error", "trace_id": scene_trace_id, "tags": [topic, f"scene{scene_number}"], "session_id": session_id} | |
) | |
try: | |
# retrieve json triple backticks | |
queries = re.search(r'```json(.*)```', queries, re.DOTALL).group(1) | |
queries = json.loads(queries) | |
except json.JSONDecodeError as e: | |
print(f"JSONDecodeError when parsing RAG queries for error fix: {e}") | |
print(f"Response text was: {queries}") | |
return [] | |
# Cache the queries | |
with open(cache_file, 'w') as f: | |
json.dump(queries, f) | |
return queries |