|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
from transformers import pipeline |
|
from typing import List |
|
|
|
# Application object; title/description/version populate the generated
# OpenAPI schema and interactive docs.
app = FastAPI(
    version="1.0.0",
    title="Hate Speech Detection API",
    description="A simple API to classify text using the unitary/toxic-bert model.",
)
|
|
|
# Load the unitary/toxic-bert text-classification pipeline once, at import
# time, so every request reuses the same model; device=-1 selects the CPU.
classifier = pipeline(
    "text-classification",
    model="unitary/toxic-bert",
    tokenizer="unitary/toxic-bert",
    device=-1,
)
|
|
|
|
|
class TextInput(BaseModel):
    """Request body for the ``/predict`` endpoint."""

    # The raw text handed to the classifier.
    text: str
|
|
|
|
|
|
|
@app.get("/")
def get_root():
    """Return a static welcome message for the API root."""
    welcome_text = "Welcome to the Hate Speech Detection API!"
    return {"message": welcome_text}
|
|
|
@app.post("/predict")
def predict_toxicity(input: TextInput):
    """Classify ``input.text`` and report toxic / non-toxic scores.

    The score of the pipeline's top prediction is treated as the toxicity
    probability. Scores strictly above 0.5 are labelled ``"Toxic"``; all
    others, including exactly 0.5, are labelled ``"Non-Toxic"``.

    Args:
        input: Request body carrying the ``text`` to classify. The name
            shadows the ``input`` builtin but is kept so existing callers
            that pass it by keyword keep working.

    Returns:
        A dict with ``label`` (``"Toxic"`` or ``"Non-Toxic"``),
        ``non-toxic-score`` (``1 - score``) and ``toxic-score`` (``score``).
    """
    # toxic-bert is BERT-based, so it cannot encode more than 512 tokens.
    # Without truncation, a long request text makes inference raise, and the
    # client gets a 500. These kwargs are forwarded to the tokenizer.
    classifier_result = classifier(input.text, truncation=True, max_length=512)

    # A single string yields a list of prediction dicts; take the top one.
    prediction = classifier_result[0]

    # NOTE(review): prediction["label"] is ignored, so this is the score of
    # whichever label ranked highest, not necessarily a "toxic" label.
    # Confirm this interpretation is intended for this model.
    toxic_score = prediction["score"]

    return {
        "label": "Toxic" if toxic_score > 0.5 else "Non-Toxic",
        "non-toxic-score": 1 - toxic_score,
        "toxic-score": toxic_score,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|