# Spaces:
# Sleeping
# Sleeping
import json
import logging
from typing import Any, Dict, List, Optional

import requests

from config.settings import Config
class LLMExtractor:
    """Extract knowledge-graph entities and relationships from text via OpenRouter.

    The extractor prompts a chat model for a JSON object with ``entities`` and
    ``relationships`` lists, filters the entities by the configured importance
    threshold, and can merge the results of several text chunks.

    Settings read from ``Config``: ``OPENROUTER_API_KEY``,
    ``OPENROUTER_BASE_URL``, ``EXTRACTION_MODEL``, ``BACKUP_MODEL``,
    ``ENTITY_IMPORTANCE_THRESHOLD``, ``MAX_ENTITIES`` and ``MAX_RELATIONSHIPS``.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Load the configuration and prepare the HTTP headers for OpenRouter."""
        self.config = Config()
        self.headers = {
            "Authorization": f"Bearer {self.config.OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
        }

    def extract_entities_and_relationships(self, text: str) -> Dict[str, Any]:
        """Extract entities and relationships from text using an LLM.

        The primary model is tried first. The backup model is tried when the
        primary call raises *or* when its reply cannot be parsed, since a
        malformed reply is as much a failure as a transport error.

        Args:
            text: The text to analyse.

        Returns:
            A dict with ``entities`` and ``relationships`` lists. When both
            models fail, the lists are empty and an ``error`` key holds
            ``"Primary: <reason>, Backup: <reason>"``.
        """
        prompt = self._create_extraction_prompt(text)
        failures = []
        attempts = (("Primary", "EXTRACTION_MODEL"), ("Backup", "BACKUP_MODEL"))
        for label, model_attr in attempts:
            try:
                # The model name is looked up lazily so the backup setting is
                # only touched when the primary attempt has actually failed.
                response = self._call_openrouter_api(
                    prompt, getattr(self.config, model_attr)
                )
            except Exception as exc:  # boundary: any failure moves to the next model
                reason = str(exc)
            else:
                result = self._parse_extraction_response(response)
                if "error" not in result:
                    return result
                reason = result["error"]
            # Record the failure even if a later attempt succeeds.
            self._logger.warning("%s extraction model failed: %s", label, reason)
            failures.append(f"{label}: {reason}")
        return {
            "entities": [],
            "relationships": [],
            "error": ", ".join(failures),
        }

    def _create_extraction_prompt(self, text: str) -> str:
        """Create the prompt for entity and relationship extraction.

        Args:
            text: The text to embed in the prompt.

        Returns:
            The full prompt, including the JSON response format the model is
            asked to follow.
        """
        return f"""
You are an expert knowledge graph extraction system. Analyze the following text and extract:
1. ENTITIES: Important people, organizations, locations, concepts, events, objects, etc.
2. RELATIONSHIPS: How these entities relate to each other
3. IMPORTANCE SCORES: Rate each entity's importance from 0.0 to 1.0 based on how central it is to the text
For each entity, provide:
- name: The entity name (standardized/canonical form)
- type: The entity type (PERSON, ORGANIZATION, LOCATION, CONCEPT, EVENT, OBJECT, etc.)
- importance: Score from 0.0 to 1.0
- description: Brief description of the entity's role/significance
For each relationship, provide:
- source: Source entity name
- target: Target entity name
- relationship: Type of relationship (works_at, located_in, part_of, causes, etc.)
- description: Brief description of the relationship
Only respond with a valid JSON object with this structure and nothing else. Your response must be valid, parsable JSON!!
=== JSON STRUCTURE FOR RESPONSE / RESPONSE FORMAT ===
{{
"entities": [
{{
"name": "entity_name",
"type": "ENTITY_TYPE",
"importance": 0.8,
"description": "Brief description"
}}
],
"relationships": [
{{
"source": "entity1",
"target": "entity2",
"relationship": "relationship_type",
"description": "Brief description"
}}
]
}}
=== END OF JSON STRUCTURE FOR RESPONSE / END OF RESPONSE FORMAT ===
TEXT TO ANALYZE:
{text}
Reply in valid json using the format above!
JSON OUTPUT:
"""

    def _call_openrouter_api(self, prompt: str, model: str) -> str:
        """Send the prompt to OpenRouter's chat completions endpoint.

        Args:
            prompt: The user message content.
            model: The model identifier to request.

        Returns:
            The text content of the first choice in the response.

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: If the HTTP status is not 200, the body is not JSON,
                or the body lacks a textual first-choice message.
        """
        if not self.config.OPENROUTER_API_KEY:
            raise ValueError("OpenRouter API key not configured")

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 2048,
            "temperature": 0.1,
        }
        response = requests.post(
            f"{self.config.OPENROUTER_BASE_URL}/chat/completions",
            headers=self.headers,
            json=payload,
            timeout=60,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"API call failed: {response.status_code} - {response.text}"
            )

        try:
            result = response.json()
            content = result["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as exc:
            # ValueError covers a non-JSON body; the rest cover a JSON body
            # that lacks the expected choices/message/content structure.
            raise RuntimeError("Invalid API response format") from exc
        if not isinstance(content, str):
            # e.g. a null content; nothing to parse downstream.
            raise RuntimeError("Invalid API response format")
        return content

    @staticmethod
    def _normalize_name(value: Any) -> str:
        """Return the lower-cased, stripped name, or "" for non-string values."""
        return value.lower().strip() if isinstance(value, str) else ""

    @staticmethod
    def _importance_of(entity: Dict[str, Any]) -> float:
        """Return the entity's importance as a float, or 0.0 if missing/invalid."""
        try:
            score = float(entity.get("importance", 0))
        except (TypeError, ValueError):
            return 0.0
        # NaN compares unequal to itself; treat it as "no score" so that
        # threshold checks and sorting stay well-defined.
        return 0.0 if score != score else score

    def _parse_extraction_response(self, response: str) -> Dict[str, Any]:
        """Parse the LLM response into structured data.

        The JSON object is taken as the span from the first ``{`` to the last
        ``}``, so surrounding prose is tolerated. Entities below
        ``ENTITY_IMPORTANCE_THRESHOLD`` are dropped; entries that are not JSON
        objects are discarded; both lists are capped at the configured maxima.

        Args:
            response: Raw text returned by the model.

        Returns:
            The parsed dict with ``entities`` and ``relationships`` lists, or a
            dict with empty lists and an ``error`` message if parsing failed.
        """
        try:
            start_idx = response.find("{")
            end_idx = response.rfind("}") + 1
            if start_idx == -1 or end_idx == 0:
                raise ValueError("No JSON found in response")
            data = json.loads(response[start_idx:end_idx])

            raw_entities = data.get("entities")
            if not isinstance(raw_entities, list):
                raw_entities = []
            raw_relationships = data.get("relationships")
            if not isinstance(raw_relationships, list):
                raw_relationships = []

            threshold = self.config.ENTITY_IMPORTANCE_THRESHOLD
            kept_entities = [
                entity
                for entity in raw_entities
                if isinstance(entity, dict)
                and self._importance_of(entity) >= threshold
            ]
            kept_relationships = [
                rel for rel in raw_relationships if isinstance(rel, dict)
            ]

            data["entities"] = kept_entities[: self.config.MAX_ENTITIES]
            data["relationships"] = kept_relationships[
                : self.config.MAX_RELATIONSHIPS
            ]
            return data
        except json.JSONDecodeError as e:
            return {
                "entities": [],
                "relationships": [],
                "error": f"JSON parsing error: {str(e)}",
            }
        except Exception as e:
            return {
                "entities": [],
                "relationships": [],
                "error": f"Response parsing error: {str(e)}",
            }

    def process_chunks(self, chunks: List[str]) -> Dict[str, Any]:
        """Process multiple text chunks and combine results.

        Args:
            chunks: Text chunks to run extraction on, in order.

        Returns:
            A dict with de-duplicated ``entities``, the ``relationships`` whose
            endpoints both survive de-duplication, and ``errors``: a list of
            per-chunk messages (1-based chunk numbers), or ``None`` if every
            chunk succeeded.
        """
        all_entities = []
        all_relationships = []
        errors = []
        for i, chunk in enumerate(chunks):
            try:
                result = self.extract_entities_and_relationships(chunk)
            except Exception as e:
                errors.append(f"Chunk {i+1}: {str(e)}")
                continue
            if "error" in result:
                errors.append(f"Chunk {i+1}: {result['error']}")
                continue
            all_entities.extend(result.get("entities", []))
            all_relationships.extend(result.get("relationships", []))

        unique_entities = self._deduplicate_entities(all_entities)
        valid_relationships = self._validate_relationships(
            all_relationships, unique_entities
        )
        return {
            "entities": unique_entities,
            "relationships": valid_relationships,
            "errors": errors or None,
        }

    def _deduplicate_entities(self, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Remove duplicate entities and rank the survivors by importance.

        Names are compared case-insensitively with surrounding whitespace
        ignored. Among duplicates the record with the highest importance is
        kept (the earliest one on ties), so an entity that is peripheral in
        one chunk but central in another is ranked by its best score. Entries
        without a usable string name are dropped.

        Args:
            entities: Entity dicts, possibly containing duplicates.

        Returns:
            At most ``MAX_ENTITIES`` entities, highest importance first.
        """
        best_by_name: Dict[str, Dict[str, Any]] = {}
        for entity in entities:
            if not isinstance(entity, dict):
                continue
            key = self._normalize_name(entity.get("name"))
            if not key:
                continue
            current = best_by_name.get(key)
            if current is None or (
                self._importance_of(entity) > self._importance_of(current)
            ):
                best_by_name[key] = entity

        ranked = sorted(
            best_by_name.values(), key=self._importance_of, reverse=True
        )
        return ranked[: self.config.MAX_ENTITIES]

    def _validate_relationships(self, relationships: List[Dict[str, Any]], entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Validate that relationships reference existing entities.

        Endpoint names are normalised the same way as in de-duplication
        (case-insensitive, surrounding whitespace ignored). A relationship
        with a missing or non-string endpoint never validates.

        Args:
            relationships: Relationship dicts with ``source`` and ``target``.
            entities: The entities a relationship may refer to.

        Returns:
            At most ``MAX_RELATIONSHIPS`` relationships whose source and
            target both name an entity in ``entities``.
        """
        entity_names = {
            self._normalize_name(entity.get("name"))
            for entity in entities
            if isinstance(entity, dict)
        }
        # An empty key would let relationships with missing endpoints match.
        entity_names.discard("")

        valid_relationships = [
            rel
            for rel in relationships
            if isinstance(rel, dict)
            and self._normalize_name(rel.get("source")) in entity_names
            and self._normalize_name(rel.get("target")) in entity_names
        ]
        return valid_relationships[: self.config.MAX_RELATIONSHIPS]