Spaces:
Sleeping
Sleeping
import gradio as gr | |
import requests | |
import os | |
# Endpoint URL and bearer token are read from the environment (on a Space
# these are configured as secrets). Both values are stripped: a value pasted
# into a settings form can carry stray whitespace or a trailing newline, which
# would otherwise end up inside the request URL or the Authorization header.
ENDPOINT_URL = os.environ.get(
    "ENDPOINT_URL",
    "https://dz0eq6vxq3nm0uh7.us-east-1.aws.endpoints.huggingface.cloud",
).strip()
HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "").strip()
# Check if the API token is configured | |
def is_token_configured():
    """Return a one-line status message about the HF_API_TOKEN setting.

    Despite the predicate-style name, the result is a display string, not a
    bool: a confirmation when the module-level ``HF_API_TOKEN`` is non-empty,
    otherwise a warning telling the operator to add the secret.
    """
    # NOTE(review): the leading glyphs in both messages look mis-encoded
    # (probably symbols read with the wrong charset) - confirm the file's
    # encoding. They are reproduced exactly as found.
    if HF_API_TOKEN:
        return "β API token is configured"
    return "β οΈ Warning: HF_API_TOKEN is not configured. The app won't work until you add this secret in your Space settings."
import requests | |
import json | |
import requests | |
def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence or ``json`` label."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # str.strip() takes a *set of characters*, not a substring, so a
        # single backtick is all that is needed to peel the fence off both ends.
        cleaned = cleaned.strip("`").strip()
    if cleaned.startswith("json"):
        # Drop the language label that follows an opening fence.
        cleaned = cleaned[len("json"):].strip()
    return cleaned


def check_safety(input_text):
    """Classify *input_text* with the remote safety endpoint.

    Posts ``{"inputs": input_text}`` to ``ENDPOINT_URL`` with a bearer token
    and renders the reply for display. The reply is expected to be a JSON
    object (or a JSON string, optionally wrapped in a Markdown code fence,
    that decodes to one) with the keys ``"Safety"``, ``"Score"`` and
    ``"Unsafe Categories"``.

    Args:
        input_text: Text to classify. ``None``, empty and whitespace-only
            values are rejected without contacting the endpoint.

    Returns:
        A human-readable string: the verdict with score and categories, or a
        warning/error message. This function never raises; every failure is
        reported through the returned text.
    """
    # NOTE(review): the leading glyphs in the status strings below look
    # mis-encoded (probably symbols read with the wrong charset) - confirm the
    # file's encoding. Existing strings are kept exactly as found.
    if not input_text or not input_text.strip():
        return "β οΈ Please enter some text to check."

    payload = {"inputs": input_text}
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {HF_API_TOKEN}",
    }

    try:
        response = requests.post(ENDPOINT_URL, json=payload, headers=headers, timeout=30)

        content_type = response.headers.get("content-type", "")
        if not content_type.startswith("application/json"):
            return f"β Error: Server returned non-JSON response:\n\n{response.text}"

        # An error reply (bad token, endpoint scaled to zero, ...) can also be
        # JSON; without this check it would be rendered as an "Unsafe" verdict.
        if not response.ok:
            return f"β Error: Server returned HTTP {response.status_code}:\n\n{response.text}"

        result = response.json()
        if isinstance(result, str):
            # The model may answer with a JSON document embedded in a string,
            # possibly fenced with triple backticks.
            result = json.loads(_strip_code_fence(result))

        # Only a mapping with a textual "Safety" field is a verdict; anything
        # else is reported as malformed rather than silently called unsafe.
        if not isinstance(result, dict) or not isinstance(result.get("Safety"), str):
            return f"β Error: Unexpected response format:\n\n{result}"

        is_safe = result["Safety"].lower() == "safe"
        score = result.get("Score", "")
        categories = result.get("Unsafe Categories", "")
        if is_safe:
            return f"β Safe\n\nSafety: safe\nScore: {score}\nUnsafe Categories: {categories}"
        return f"β Unsafe\n\nSafety: unsafe\nScore: {score}\nUnsafe Categories: {categories}"
    except requests.exceptions.Timeout:
        return "β Error: Request timed out."
    except requests.exceptions.ConnectionError:
        return "β Error: Failed to connect to the endpoint."
    except Exception as e:
        # UI boundary: surface any remaining failure as text instead of
        # letting it propagate into the Gradio event handler.
        return f"β Error: {str(e)}"
# Build the single-page UI: a text box and button on the left, the
# classification result on the right, plus clickable example prompts.
with gr.Blocks(title="Safety Content Classifier", css="footer {display: none !important}") as demo:
    gr.Markdown("# Safety Content Classifier")
    gr.Markdown("## Connected to external safety model endpoint")
    # Evaluated once at build time, so the banner reflects the token state at
    # startup.
    token_status = gr.Markdown(is_token_configured())

    with gr.Accordion("About this demo", open=False):
        gr.Markdown(
            "\n"
            "This demo uses an external API endpoint to classify text and images based on safety policies.\n"
            "It checks content against the following categories:\n"
            "- Harassment\n"
            "- Dangerous Content\n"
            "- Hate Speech\n"
            "- Sexually Explicit Information\n"
            "The model will respond with 'Safe' or 'Unsafe' followed by any violated categories.\n"
        )

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Enter text to check",
                placeholder="Type here...",
                lines=5,
            )
            check_button = gr.Button("Check Safety", variant="primary")
        with gr.Column():
            output = gr.Textbox(
                label="Safety Result",
                lines=8,
            )

    # Both clicking the button and submitting the text box run the same check.
    check_button.click(fn=check_safety, inputs=[input_text], outputs=output)
    input_text.submit(fn=check_safety, inputs=[input_text], outputs=output)

    # Each example row must supply one value per component in `inputs`. The
    # previous rows carried a second (image path) column and two image-only
    # rows with empty text, although only the text box is wired up here.
    gr.Examples(
        examples=[
            ["Hello, how are you today?"],
            ["I want to learn how to make a bomb."],
        ],
        inputs=[input_text],
    )

demo.launch()