"""deepeval wrapper for a LangChain chat model, with a LangGraph writer
workflow as a fallback when the model emits malformed JSON."""

import json
import os
from typing import Any

from deepeval.models.base_model import DeepEvalBaseLLM
from dotenv import load_dotenv

from src.evaluation.writer.agent_write import create_workflow_async, create_workflow_sync
from src.helpers.helper import remove_markdown
from src.utils.api_key_manager import with_api_manager


class LangChainWrapper(DeepEvalBaseLLM):
    """Adapts a LangChain chat model to deepeval's DeepEvalBaseLLM interface."""

    def __init__(self):
        # Load .env so MODEL_NAME (and any provider API keys) are available.
        load_dotenv()
        self.model_name = os.getenv("MODEL_NAME")

    def _invoke_llm_sync(self, prompt: Any) -> str:
        @with_api_manager(temperature=0.0, top_p=1.0)
        def _inner_invoke_sync(*args, **kwargs):
            # The decorator injects a configured LangChain client as `llm`.
            response = kwargs['llm'].invoke(prompt)
            return response.content.strip()

        return _inner_invoke_sync()

    async def _invoke_llm_async(self, prompt: Any) -> str:
        @with_api_manager(temperature=0.0, top_p=1.0)
        async def _inner_invoke_async(*args, **kwargs):
            # The decorator injects a configured LangChain client as `llm`.
            response = await kwargs['llm'].ainvoke(prompt)
            return response.content.strip()

        return await _inner_invoke_async()

    def _parse_as_schema(self, raw_text: str, schema: Any) -> Any:
        # Strip markdown fences before parsing; json.loads raises
        # json.JSONDecodeError on malformed JSON, which callers catch.
        cleaned_text = remove_markdown(raw_text)
        data = json.loads(cleaned_text)

        try:
            return schema(**data)
        except Exception:
            print(f"Failed to parse data for schema: {schema}")
            raise

    def generate(self, prompt: Any, schema: Any = None) -> Any:
        raw_text = self._invoke_llm_sync(prompt)

        if schema is None:
            return raw_text

        try:
            return self._parse_as_schema(raw_text, schema)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON data: {e}\nUsing LangGraph fallback...")

        # Fallback: have the LangGraph writer workflow produce valid JSON.
        # Note that this path returns the parsed dict, not a schema instance.
        initial_state = {
            "initial_prompt": prompt,
            "plan": "",
            "write_steps": [],
            "final_json": ""
        }
        app = create_workflow_sync()
        final_state = app.invoke(initial_state)
        output = remove_markdown(final_state['final_json'])

        try:
            return json.loads(output)
        except json.JSONDecodeError as e:
            raise Exception(f"Cannot parse JSON data: {e}") from e

    async def a_generate(self, prompt: Any, schema: Any = None) -> Any:
        raw_text = await self._invoke_llm_async(prompt)

        if schema is None:
            return raw_text

        try:
            return self._parse_as_schema(raw_text, schema)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON data: {e}\nUsing LangGraph fallback...")

        # Fallback: have the LangGraph writer workflow produce valid JSON.
        # Note that this path returns the parsed dict, not a schema instance.
        initial_state = {
            "initial_prompt": prompt,
            "plan": "",
            "write_steps": [],
            "final_json": ""
        }
        app = create_workflow_async()
        final_state = await app.ainvoke(initial_state)
        output = remove_markdown(final_state['final_json'])

        try:
            return json.loads(output)
        except json.JSONDecodeError as e:
            raise Exception(f"Cannot parse JSON data: {e}") from e

    def get_model_name(self) -> str:
        return f"LangChainWrapper for {self.model_name}"

    def load_model(self, *, llm: Any):
        # Routed through the API manager for consistency; the inner function
        # ignores the injected client and returns the provided model as-is.
        @with_api_manager(temperature=0.0, top_p=1.0)
        def inner_load_model(*args, **kwargs):
            return llm

        return inner_load_model()
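

# Usage sketch (illustrative, not part of the wrapper): assumes MODEL_NAME and
# the relevant provider API keys are set in the environment so that
# with_api_manager can inject a working client; `_Answer` is a hypothetical
# Pydantic schema used only for this demo.
if __name__ == "__main__":
    from pydantic import BaseModel

    class _Answer(BaseModel):
        answer: str

    wrapper = LangChainWrapper()
    print(wrapper.get_model_name())

    # Plain-text generation.
    print(wrapper.generate("Say hello in one word."))

    # Structured generation: parsed into _Answer, falling back to the
    # LangGraph workflow if the model's JSON is malformed.
    print(wrapper.generate('Reply only with JSON: {"answer": "hi"}', schema=_Answer))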