|
import functools
import hmac
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import List

import pandas as pd
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from jose import JWTError, jwt
from openai import OpenAI
from pydantic import BaseModel
|
|
|
|
|
logging.basicConfig(level=logging.INFO)

app = FastAPI()


def _signing_key_from_env(var_name: str) -> str:
    """Return the JWT signing key stored in environment variable *var_name*.

    If the variable is unset or empty, fall back to a random 256-bit key
    generated for this process and log a warning. The fallback keeps the app
    usable in development, but tokens then stop validating after a restart
    and are not shared between separate worker processes.
    """
    key = os.environ.get(var_name)
    if key:
        return key
    logging.warning(
        "%s is not set; using a random per-process JWT signing key", var_name
    )
    return secrets.token_hex(32)


# Signing keys must come from the environment: a key committed to source lets
# anyone who can read the code forge valid tokens for any user.
SECRET_KEY = _signing_key_from_env("prime_auth")  # signs access tokens
REFRESH_SECRET_KEY = _signing_key_from_env("prolonged_auth")  # signs refresh tokens

ALGORITHM = "HS256"

ACCESS_TOKEN_EXPIRE_MINUTES = 30

REFRESH_TOKEN_EXPIRE_DAYS = 7

# Bearer-token extractor used by verify_access_token; tokenUrl points clients
# (and the generated OpenAPI docs) at the /login endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
|
|
|
|
|
def load_credentials():
    """Collect username -> password pairs from numbered environment variables.

    Slots ``login_<i>`` / ``password_<i>`` are read for i = 1..50. A slot is
    kept only when both values are present and non-empty; if two slots share
    a username, the higher-numbered slot wins.
    """
    env = os.environ
    slots = (
        (env.get(f"login_{idx}"), env.get(f"password_{idx}"))
        for idx in range(1, 51)
    )
    return {user: secret for user, secret in slots if user and secret}
|
|
|
|
|
def authenticate_user(username: str, password: str):
    """Check *username* / *password* against the environment-configured accounts.

    Args:
        username: Login name submitted by the client.
        password: Plain-text password submitted by the client.

    Returns:
        The username when the password matches, otherwise ``None``.
    """
    credentials_dict = load_credentials()
    expected = credentials_dict.get(username)
    if expected is None:
        return None
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of the password were correct. Bytes are compared
    # because hmac.compare_digest rejects non-ASCII str arguments.
    if hmac.compare_digest(expected.encode("utf-8"), password.encode("utf-8")):
        return username
    return None
|
|
|
|
|
def create_token(data: dict, expires_delta: timedelta, secret_key: str):
    """Encode *data* as a JWT signed with *secret_key* using ``ALGORITHM``.

    Args:
        data: Claims to embed (this module passes ``{"sub": username}``). The
            dict is copied, so the caller's mapping is never modified.
        expires_delta: Token lifetime; added to the current UTC time to form
            the ``exp`` claim.
        secret_key: Key used to sign the token.

    Returns:
        The encoded token produced by ``jwt.encode``.
    """
    to_encode = data.copy()
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12
    # and returns a naive value that is easy to mistake for local time.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
|
|
|
|
|
def verify_token(token: str, secret_key: str):
    """Decode *token* with *secret_key* and return its ``sub`` claim.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``jwt.decode`` raises ``JWTError`` or the decoded payload has no
            ``sub`` claim.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    subject = claims.get("sub")
    if subject is None:
        raise unauthorized
    return subject
|
|
|
|
|
def verify_access_token(token: str = Depends(oauth2_scheme)):
    """FastAPI dependency that resolves a bearer access token to a username.

    The token is supplied by ``oauth2_scheme`` and validated against
    ``SECRET_KEY``; invalid tokens raise a 401 from ``verify_token``.
    """
    return verify_token(token, secret_key=SECRET_KEY)
|
|
|
|
|
def verify_refresh_token(token: str):
    """Return the username carried by a refresh *token*.

    Refresh tokens are validated against ``REFRESH_SECRET_KEY``. An access
    token (signed with ``SECRET_KEY``) fails here unless the two keys are
    configured to the same value.
    """
    return verify_token(token, secret_key=REFRESH_SECRET_KEY)
|
|
|
|
|
@functools.lru_cache(maxsize=4)
def _read_parquet_cached(path: str, mtime_ns: int):
    """Parse *path* once per (path, modification-time) pair.

    ``mtime_ns`` is used only as part of the cache key: when the file changes
    on disk, the next call misses the cache and re-reads it.
    """
    return pd.read_parquet(path)


def load_data(database_file):
    """Load the parquet database at *database_file* as a DataFrame.

    Local files are parsed once and cached, keyed on their modification
    time, instead of being re-read on every request. Each call returns an
    independent ``copy()`` so callers can add or overwrite columns without
    affecting the cache or each other. Note that a deep copy does not clone
    Python objects held in object columns, so in-place mutation of an
    individual cell value would still be shared.

    Inputs that are not an existing local file (URLs, file-like objects, ...)
    are passed straight to ``pd.read_parquet`` as before.
    """
    if isinstance(database_file, (str, os.PathLike)) and os.path.isfile(database_file):
        mtime_ns = os.stat(database_file).st_mtime_ns
        return _read_parquet_cached(os.fspath(database_file), mtime_ns).copy()
    return pd.read_parquet(database_file)
|
|
|
|
|
def generate_openai_embeddings(client, text):
    """Embed *text* with the ``text-embedding-3-small`` model via *client*.

    Returns the ``embedding`` of the first item in the response's ``data``.
    """
    result = client.embeddings.create(model="text-embedding-3-small", input=text)
    first_item = result.data[0]
    return first_item.embedding
|
|
|
|
|
def cosine_similarity(embedding_0, embedding_1):
    """Return the cosine similarity of two equal-length numeric vectors.

    Returns ``0.0`` when either vector has zero magnitude (the angle is
    undefined) instead of raising ``ZeroDivisionError``.

    Raises:
        ValueError: if the vectors differ in length. ``zip`` would otherwise
            silently truncate the longer one and yield a meaningless score,
            e.g. when the database was embedded with a different model.
    """
    if len(embedding_0) != len(embedding_1):
        raise ValueError(
            f"embedding length mismatch: {len(embedding_0)} != {len(embedding_1)}"
        )
    dot_product = sum(a * b for a, b in zip(embedding_0, embedding_1))
    norm_0 = sum(a * a for a in embedding_0) ** 0.5
    norm_1 = sum(b * b for b in embedding_1) ** 0.5
    if norm_0 == 0 or norm_1 == 0:
        return 0.0
    return dot_product / (norm_0 * norm_1)
|
|
|
|
|
def search_query(client, query, df, n=3):
    """Return the *n* rows of *df* most similar to *query*.

    *query* is embedded once with ``generate_openai_embeddings`` and compared
    against every value of ``df.openai_embedding`` with ``cosine_similarity``.

    Args:
        client: Client passed through to ``generate_openai_embeddings``.
        query: Text to search for.
        df: DataFrame with an ``openai_embedding`` column. It is NOT modified.
        n: Number of rows to return.

    Returns:
        A new DataFrame holding the top-*n* rows of *df* plus a
        ``similarities`` column, ordered by descending similarity.
    """
    query_embedding = generate_openai_embeddings(client, query)
    similarities = df.openai_embedding.apply(
        lambda row_embedding: cosine_similarity(row_embedding, query_embedding)
    )
    # assign() builds a new frame, so the caller's DataFrame is never mutated
    # (previously a "similarities" column was written into it as a side effect).
    scored = df.assign(similarities=similarities)
    # Stable sort: rows with equal scores keep their original relative order.
    return scored.sort_values("similarities", ascending=False, kind="mergesort").head(n)
|
|
|
|
|
class QueryInput(BaseModel):
    """Request body for ``POST /search``."""

    # Free-text query; the search endpoint embeds it and ranks database rows by
    # cosine similarity to it.
    query: str
|
|
|
|
|
class SearchResult(BaseModel):
    """One ranked hit returned by ``POST /search``."""

    # Text of the matching row. The search endpoint fills it from a column
    # named "ext" — NOTE(review): confirm that is the intended column name.
    text: str

    # Cosine similarity between the query embedding and the row's embedding.
    similarity: float
|
|
|
|
|
class TokenResponse(BaseModel):
    """Response body for ``POST /login`` and ``POST /refresh``."""

    # JWT signed with SECRET_KEY, valid for ACCESS_TOKEN_EXPIRE_MINUTES.
    access_token: str

    # JWT signed with REFRESH_SECRET_KEY, valid for REFRESH_TOKEN_EXPIRE_DAYS.
    refresh_token: str

    # Token scheme; both endpoints in this module return "bearer".
    token_type: str
|
|
|
|
|
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end page ``static/index.html`` at the site root.

    The path is relative to the process's current working directory.
    """
    html_page = "static/index.html"
    return FileResponse(html_page, media_type="text/html")
|
|
|
|
|
@app.post("/login", response_model=TokenResponse)
def login(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange form-encoded credentials for an access/refresh token pair.

    Raises:
        HTTPException: 401 when the username/password pair is not accepted
            by ``authenticate_user``.
    """
    logging.info("Login attempt for user: %s", form_data.username)

    username = authenticate_user(form_data.username, form_data.password)
    if not username:
        logging.warning("Authentication failed for user: %s", form_data.username)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    def _issue(lifetime, key):
        # Each token carries only the subject claim plus its own expiry.
        return create_token(data={"sub": username}, expires_delta=lifetime, secret_key=key)

    issued_access = _issue(timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), SECRET_KEY)
    issued_refresh = _issue(timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS), REFRESH_SECRET_KEY)

    logging.info("Tokens issued for user: %s", username)
    return {
        "access_token": issued_access,
        "refresh_token": issued_refresh,
        "token_type": "bearer",
    }
|
|
|
|
|
@app.post("/refresh", response_model=TokenResponse)
def refresh(refresh_token: str):
    """Mint a new access token from a valid refresh token.

    ``refresh_token`` is declared as a plain ``str`` parameter, so FastAPI
    reads it from the query string. The same refresh token is returned
    unchanged; it is not rotated.

    Raises:
        HTTPException: 401 from ``verify_token`` if the refresh token is invalid.
    """
    # NOTE(review): tokens passed in query strings tend to be recorded in
    # access and proxy logs; consider accepting it in the request body.
    username = verify_refresh_token(refresh_token)
    new_access_token = create_token(
        data={"sub": username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        secret_key=SECRET_KEY,
    )
    return {
        "access_token": new_access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
|
|
|
|
|
@app.post("/search", response_model=List[SearchResult])
def search(
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
):
    """Return the three database passages most similar to the posted query.

    Requires a valid bearer access token (``username`` is resolved by the
    ``verify_access_token`` dependency). An OpenAI client is built from the
    ``OPENAI_API_KEY`` environment variable on every call.
    """
    openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    parquet_path = "/[openai_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet"
    frame = load_data(parquet_path)
    logging.info("Database loaded successfully")

    top_hits = search_query(openai_client, query_input.query, frame, n=3)

    # NOTE(review): the result text is read from a column named "ext";
    # confirm this matches the parquet schema.
    return [
        SearchResult(text=hit["ext"], similarity=hit["similarities"])
        for _, hit in top_hits.iterrows()
    ]
|
|
|
# Expose the ./static directory under /home; html=True lets StaticFiles
# answer directory requests with their index.html.
app.mount(path="/home", app=StaticFiles(directory="static", html=True), name="static")


if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")