from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = FastAPI()

# Allow cross-origin requests so a separate frontend can call this API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the tokenizer and model once at startup (a 7B model; requires substantial memory)
model_name = "togethercomputer/RedPajama-INCITE-7B-Base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


class PromptRequest(BaseModel):
    prompt: str


@app.post("/api/generate-story")
async def generate_story(req: PromptRequest):
    prompt = req.prompt.strip()
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Sampled generation; no gradients are needed at inference time
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            top_p=0.9,
            temperature=0.85,
            repetition_penalty=1.2,
        )
    story = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"story": story}
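
# Example usage (a sketch; assumes this file is saved as app.py, uvicorn is installed,
# and port 7860 is used, as is common on Hugging Face Spaces):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST http://localhost:7860/api/generate-story \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Once upon a time"}'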