# sasagema's picture
# Added files
# 85c1145
# raw
# history blame
# 1.42 kB
import json
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from promptSearchEngine import PromptSearchEngine
from vectorizer import Vectorizer
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# Name of the sentence-embedding model and of the prompt dataset used below.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
DATASET = "Gustavosta/Stable-Diffusion-Prompts"
# The three statements below run once, at import time: the embedding model is
# instantiated, the dataset is loaded, and the search engine is built from the
# dataset's "Prompt" column together with the model.
model = SentenceTransformer(EMBEDDING_MODEL)
# NOTE(review): "test[:1%]" looks like Hugging Face split-slicing syntax
# (presumably the first 1% of the "test" split) - confirm this small sample
# is intentional.
dataset = load_dataset(DATASET , split="test[:1%]")
promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)
class SearchRequest(BaseModel):
    """Request body accepted by the ``POST /search`` endpoint."""

    # Text to search for among the indexed prompts.
    query: str

    # How many results to return. Defaults to 5 when the field is omitted;
    # the annotation also permits an explicit null (None).
    n: int | None = 5
# FastAPI application instance; the route decorators below register on it.
app = FastAPI()
@app.get("/")
async def root():
    """Landing endpoint: return a short hint pointing at ``GET /docs``."""
    hint = {"message": "GET /docs"}
    return hint
@app.get("/search")
async def search(q: str, n: int = 5):
    """Search the indexed prompts for those most similar to ``q``.

    Args:
        q: Query text, taken from the ``q`` query parameter.
        n: Number of results requested from the search engine (default 5).

    Returns:
        ``{"message": "Enter query"}`` when ``q`` is empty or whitespace-only;
        otherwise whatever ``promptSearchEngine.stringify_prompts`` produces
        for the matches.

    Raises:
        HTTPException: 404 when the search engine returns no matches.
    """
    # Guard clause: a blank query is answered with a hint, not searched.
    if not q.strip():
        return {"message": "Enter query"}

    matches = promptSearchEngine.most_similar(q, n)
    if matches:
        return promptSearchEngine.stringify_prompts(matches)
    raise HTTPException(status_code=404, detail="No prompts found.")
@app.post("/search")
async def searchPost(request: SearchRequest):
    """Search the indexed prompts using a JSON request body.

    Args:
        request: Parsed body carrying the ``query`` text and the optional
            result count ``n``.

    Returns:
        ``{"data": [...]}`` where each entry is a dict with a float
        ``"similarity"`` and the matching ``"prompt"``.

    Raises:
        HTTPException: 404 when the search engine returns no matches.
    """
    # SearchRequest.n is typed ``int | None``, so a body with ``"n": null``
    # validates and would otherwise forward None as the result count. Fall
    # back to 5, the field's declared default (also the GET endpoint's).
    n = 5 if request.n is None else request.n

    results = promptSearchEngine.most_similar(request.query, n)
    if not results:
        raise HTTPException(status_code=404, detail="No prompts found.")

    # float(...) converts each score to a plain Python float before it is
    # placed in the response payload.
    formatted_results = [
        {"similarity": float(similarity), "prompt": prompt}
        for similarity, prompt in results
    ]
    return {"data": formatted_results}