Spaces:
Running
Running
import os
import warnings
import zipfile
from typing import Any, Dict, List

import bm25s
import nltk
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from nltk.stem import WordNetLemmatizer
from openai import OpenAI
from pydantic import BaseModel
from sklearn.preprocessing import MinMaxScaler
# --- One-time application setup (runs at import time; order matters) ---

# Load environment variables (e.g. the GEMINI API key) from a local .env file.
load_dotenv()

# WordNet corpus is required by WordNetLemmatizer below.
nltk.download('wordnet')

# If a zipped BM25 index ships alongside the app, unpack it into the working
# directory BEFORE attempting to load the index from disk.
if os.path.exists("bm25s.zip"):
    with zipfile.ZipFile("bm25s.zip", 'r') as zip_ref:
        zip_ref.extractall(".")

# Pre-built BM25 index over 3GPP specification documents; load_corpus=True
# keeps the raw document text available so search results can return content.
bm25_engine = bm25s.BM25.load("3gpp_bm25_docs", load_corpus=True)
lemmatizer = WordNetLemmatizer()

# OpenAI-compatible client pointed at Google's Gemini OpenAI-compat endpoint.
# NOTE(review): the env var is named GEMINI (not GEMINI_API_KEY) — confirm
# this matches the deployment's secret configuration.
llm = OpenAI(api_key=os.environ.get("GEMINI"),
             base_url="https://generativelanguage.googleapis.com/v1beta/openai/")

warnings.filterwarnings("ignore")

app = FastAPI(title="RAGnarok",
              description="API to search specifications for RAG")

# Serve bundled front-end assets (JS/CSS) under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")

# CORS is wide open ("*") — acceptable for a public demo; tighten for production.
origins = [
    "*",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class SearchRequest(BaseModel):
    """Request body for the specification-search endpoint."""

    keyword: str    # free-text query matched against the BM25 index
    threshold: int  # minimum similarity to include, as a percentage (0-100)
class SearchResponse(BaseModel):
    """Response body for the specification-search endpoint."""

    # Each dict carries id / title / section / content / similarity keys
    # (populated by search_specifications).
    results: List[Dict[str, Any]]
class ChatRequest(BaseModel):
    """Request body for the chat endpoint."""

    # OpenAI-style chat history: [{"role": ..., "content": ...}, ...]
    messages: List[Dict[str, str]]
    model: str  # model name forwarded verbatim to the LLM backend
class ChatResponse(BaseModel):
    """Response body for the chat endpoint: the assistant's reply text."""

    response: str
async def main_menu():
    """Serve the front-end entry page (templates/index.html)."""
    index_page = os.path.join("templates", "index.html")
    return FileResponse(index_page)
def question_the_sources(req: ChatRequest):
    """Forward the chat conversation to the configured LLM and wrap its reply."""
    completion = llm.chat.completions.create(
        messages=req.messages,
        model=req.model
    )
    answer = completion.choices[0].message.content
    return ChatResponse(response=answer)
def search_specifications(req: SearchRequest):
    """Rank 3GPP specification documents against req.keyword.

    Pipeline:
      1. Lemmatize the query and retrieve BM25 scores for the entire corpus
         (all documents, so normalization sees the full score range).
      2. Boost scores for exact title / spec-id matches.
      3. Keep only the best-scoring chunk per specification id and min-max
         normalize the per-spec scores to [0, 1].
      4. Return specs (best first) whose normalized score clears
         req.threshold, expressed as a percentage.

    Returns a SearchResponse whose results each carry
    id / title / section / content / similarity keys.
    """
    keywords = req.keyword
    threshold = req.threshold
    query = lemmatizer.lemmatize(keywords)

    query_tokens = bm25s.tokenize(query)
    results, scores = bm25_engine.retrieve(query_tokens, k=len(bm25_engine.corpus))

    def calculate_boosted_score(metadata, score, query):
        # NOTE(review): these are whole-string singleton-set intersections —
        # the boost only fires when the (lemmatized, lowercased) query EXACTLY
        # equals the title, or the raw lowercased query equals the spec id.
        # If token-level matching was intended, this needs a tokenized overlap.
        title = {lemmatizer.lemmatize(metadata['title']).lower()}
        q = {query.lower()}
        spec_id_presence = 0.5 if len(q & {metadata['id']}) > 0 else 0
        booster = len(q & title) * 0.5
        return score + spec_id_presence + booster

    # Best (boosted) chunk per specification id.
    # (Dropped the original spec_indices dict: it was written but never read.)
    spec_scores = {}
    spec_details = {}
    for i in range(results.shape[1]):
        doc = results[0, i]
        score = scores[0, i]
        spec = doc["metadata"]["id"]
        boosted_score = calculate_boosted_score(doc['metadata'], score, query)
        if spec not in spec_scores or boosted_score > spec_scores[spec]:
            spec_scores[spec] = boosted_score
            spec_details[spec] = {
                'original_score': score,
                'boosted_score': boosted_score,
                'doc': doc
            }

    def normalize_scores(scores_dict):
        # Min-max scale per-spec scores into [0, 1]; empty input -> empty dict.
        # NOTE(review): when only one spec matches (or all scores are equal),
        # MinMaxScaler maps everything to 0, so the threshold filter below
        # drops it — confirm whether that degenerate case should return 1.0.
        if not scores_dict:
            return {}
        scores_array = np.array(list(scores_dict.values())).reshape(-1, 1)
        normalized = MinMaxScaler().fit_transform(scores_array).flatten()
        return dict(zip(scores_dict.keys(), normalized))

    normalized_scores = normalize_scores(spec_scores)
    for spec in spec_details:
        spec_details[spec]["normalized_score"] = normalized_scores[spec]

    # Emit results best-first; scores are sorted descending, so the first spec
    # below the cutoff means every remaining one is too — break, don't continue.
    results_out = []
    cutoff = threshold / 100  # hoisted loop-invariant; threshold is 0-100
    unique_specs = sorted(normalized_scores.keys(),
                          key=lambda x: normalized_scores[x], reverse=True)
    for spec in unique_specs:
        details = spec_details[spec]
        if details['normalized_score'] < cutoff:
            break
        metadata = details['doc']['metadata']
        results_out.append({
            'id': metadata['id'],
            'title': metadata['title'],
            'section': metadata['section_title'],
            'content': details['doc']['text'],
            'similarity': int(details['normalized_score'] * 100)
        })
    return SearchResponse(results=results_out)