Update app.py
app.py
CHANGED
@@ -2,251 +2,643 @@ import gradio as gr
- [old lines 5-252 removed: the previous synchronous implementation, built around a module-level CLIENT and a safe_call() helper with per-task wrappers and basic input validation (e.g. the `[MASK]` and question/context checks); the old side of this diff did not survive extraction beyond fragments]
import os
import json
import logging
import asyncio
from typing import Dict, List, Any, Optional, Union
from dataclasses import dataclass
from enum import Enum
import tempfile
import base64
from pathlib import Path

try:
    from huggingface_hub import InferenceClient
    # HfHubHTTPError is the HTTP error type raised by huggingface_hub client calls.
    from huggingface_hub.utils import HfHubHTTPError
except ImportError:
    raise ImportError("Please install huggingface_hub: pip install huggingface_hub")
# ──────────────────────────────────────────────────────────────────────────────
# 🏗️ ARCHITECTURE & CONFIGURATION
# ──────────────────────────────────────────────────────────────────────────────

class TaskType(Enum):
    """Enumeration of supported AI tasks."""
    ASR = "automatic_speech_recognition"
    CHAT = "chat_completion"
    FILL_MASK = "fill_mask"
    QA = "question_answering"
    SUMMARIZATION = "summarization"
    TEXT_GENERATION = "text_generation"
    IMAGE_CLASSIFICATION = "image_classification"
    FEATURE_EXTRACTION = "feature_extraction"

@dataclass
class ModelConfig:
    """Configuration for AI models."""
    name: str
    task_type: TaskType
    timeout: int = 45
    max_retries: int = 3

# 🎯 Model Registry - easily extensible for different models
MODEL_REGISTRY = {
    TaskType.ASR: ModelConfig("openai/whisper-large-v3", TaskType.ASR),
    TaskType.CHAT: ModelConfig("microsoft/DialoGPT-medium", TaskType.CHAT),
    TaskType.FILL_MASK: ModelConfig("google-bert/bert-base-uncased", TaskType.FILL_MASK),
    TaskType.QA: ModelConfig("deepset/roberta-base-squad2", TaskType.QA),
    TaskType.SUMMARIZATION: ModelConfig("facebook/bart-large-cnn", TaskType.SUMMARIZATION),
    TaskType.TEXT_GENERATION: ModelConfig("gpt2", TaskType.TEXT_GENERATION),
    TaskType.IMAGE_CLASSIFICATION: ModelConfig("google/vit-base-patch16-224", TaskType.IMAGE_CLASSIFICATION),
    TaskType.FEATURE_EXTRACTION: ModelConfig("sentence-transformers/all-MiniLM-L6-v2", TaskType.FEATURE_EXTRACTION),
}
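# Adding a task/model pair is a one-line registry change, e.g. (illustrative
# sketch only - the model id below is a placeholder, not part of this commit):
#   MODEL_REGISTRY[TaskType.SUMMARIZATION] = ModelConfig(
#       "google/pegasus-xsum", TaskType.SUMMARIZATION, timeout=60, max_retries=2
#   )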
# ──────────────────────────────────────────────────────────────────────────────
# 🔧 INFRASTRUCTURE & UTILITIES
# ──────────────────────────────────────────────────────────────────────────────

class AIInferenceEngine:
    """
    High-performance AI inference engine with robust error handling,
    async support, and intelligent retry mechanisms.
    """

    def __init__(self):
        self.api_token = os.getenv("HF_API_TOKEN")
        if not self.api_token:
            # For demo purposes, we'll work without a token but with limitations
            logging.warning("HF_API_TOKEN not set. Some features may be limited.")
            self.client = None
        else:
            self.client = InferenceClient(api_key=self.api_token)

        self._setup_logging()

    def _setup_logging(self):
        """Configure structured logging for better debugging."""
        logging.basicConfig(
            format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
            level=logging.INFO,
            datefmt="%Y-%m-%d %H:%M:%S"
        )
        self.logger = logging.getLogger(self.__class__.__name__)

    async def _safe_inference_call(self, func, *args, **kwargs) -> Dict[str, Any]:
        """
        Execute inference calls with comprehensive error handling and retries.
        """
        if not self.client:
            return {
                "error": "API token not configured. Please set HF_API_TOKEN environment variable.",
                "success": False
            }

        model_config = kwargs.pop('model_config', None)
        max_retries = model_config.max_retries if model_config else 3
        timeout = model_config.timeout if model_config else 45

        for attempt in range(max_retries):
            try:
                # Add a timeout to prevent hanging. InferenceClient methods take
                # no per-call `timeout` kwarg (it is a constructor argument), so
                # the per-task timeout is enforced around the threaded call.
                result = await asyncio.wait_for(
                    asyncio.to_thread(func, *args, **kwargs),
                    timeout=timeout,
                )

                # Normalize response format
                if isinstance(result, dict) and "error" in result:
                    return {"error": result["error"], "success": False}

                return {"data": result, "success": True}

            except HfHubHTTPError as e:
                self.logger.error(f"API Error (attempt {attempt + 1}): {e}")
                if attempt == max_retries - 1:
                    return {"error": f"API Error: {str(e)}", "success": False}
                await asyncio.sleep(2 ** attempt)  # Exponential backoff

            except Exception as e:
                self.logger.exception(f"Unexpected error (attempt {attempt + 1}): {e}")
                if attempt == max_retries - 1:
                    return {"error": f"Unexpected error: {str(e)}", "success": False}
                await asyncio.sleep(1)

        return {"error": "Max retries exceeded", "success": False}

# Global inference engine instance
inference_engine = AIInferenceEngine()
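# Every call through the engine returns the same normalized envelope, so the
# task handlers below only ever branch on result["success"]. Illustrative shapes:
#   {"data": <raw InferenceClient output>, "success": True}
#   {"error": "API Error: ...", "success": False}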
# ──────────────────────────────────────────────────────────────────────────────
# 🤖 TASK IMPLEMENTATIONS
# ──────────────────────────────────────────────────────────────────────────────

async def process_speech_recognition(audio_data) -> str:
    """Process audio file for speech recognition."""
    if audio_data is None:
        return "⚠️ Please upload an audio file or record audio."

    try:
        # Handle different audio input types
        if isinstance(audio_data, tuple):
            sample_rate, audio_array = audio_data
            # Save numpy array to a temporary file (requires scipy)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                import scipy.io.wavfile as wav
                wav.write(tmp_file.name, sample_rate, audio_array)
                audio_path = tmp_file.name
        else:
            audio_path = audio_data

        if not inference_engine.client:
            return "🔑 API token required for speech recognition."

        config = MODEL_REGISTRY[TaskType.ASR]
        result = await inference_engine._safe_inference_call(
            inference_engine.client.automatic_speech_recognition,
            audio_path,
            model=config.name,
            model_config=config
        )

        if result["success"]:
            text = result["data"].get("text", "No transcription available")
            return f"🎤 **Transcription:** {text}"
        else:
            return f"❌ Error: {result['error']}"

    except Exception as e:
        return f"❌ Processing error: {str(e)}"

async def process_chat(messages_json: str) -> str:
    """Process chat completion request."""
    if not messages_json.strip():
        return "⚠️ Please enter a valid JSON message format."

    try:
        messages = json.loads(messages_json)
        if not isinstance(messages, list) or not messages:
            return "⚠️ Messages must be a non-empty JSON array."

        # Validate message structure
        for msg in messages:
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                return "⚠️ Each message must have 'role' and 'content' fields."

        if not inference_engine.client:
            return "🔑 API token required for chat completion."

        config = MODEL_REGISTRY[TaskType.CHAT]
        result = await inference_engine._safe_inference_call(
            inference_engine.client.chat.completions.create,
            model=config.name,
            messages=messages,
            max_tokens=150,
            model_config=config
        )

        if result["success"]:
            response = result["data"]
            if hasattr(response, "choices") and response.choices:
                reply = response.choices[0].message.content
                return f"🤖 **Assistant:** {reply}"
            else:
                return "🤖 **Assistant:** I'm here to help! How can I assist you today?"
        else:
            return f"❌ Error: {result['error']}"

    except json.JSONDecodeError:
        return "⚠️ Invalid JSON format. Please check your syntax."
    except Exception as e:
        return f"❌ Processing error: {str(e)}"

async def process_fill_mask(text: str) -> str:
    """Process fill mask task."""
    if not text or "[MASK]" not in text:
        return "⚠️ Input must contain the token `[MASK]`."

    if not inference_engine.client:
        return "🔑 API token required for fill mask."

    config = MODEL_REGISTRY[TaskType.FILL_MASK]
    result = await inference_engine._safe_inference_call(
        inference_engine.client.fill_mask,
        text,
        model=config.name,
        model_config=config
    )

    if result["success"]:
        predictions = result["data"]
        if isinstance(predictions, list):
            formatted_results = []
            for i, pred in enumerate(predictions[:5], 1):
                token = pred.get("token_str", "").strip()
                score = pred.get("score", 0)
                formatted_results.append(f"{i}. **{token}** (confidence: {score:.3f})")
            return "🎭 **Top Predictions:**\n" + "\n".join(formatted_results)
        else:
            return f"🎭 **Result:** {predictions}"
    else:
        return f"❌ Error: {result['error']}"

async def process_question_answering(question: str, context: str) -> str:
    """Process question answering task."""
    if not question or not context:
        return "⚠️ Both question and context are required."

    if not inference_engine.client:
        return "🔑 API token required for question answering."

    config = MODEL_REGISTRY[TaskType.QA]
    result = await inference_engine._safe_inference_call(
        inference_engine.client.question_answering,
        question=question.strip(),
        context=context.strip(),
        model=config.name,
        model_config=config
    )

    if result["success"]:
        answer_data = result["data"]
        answer = answer_data.get("answer", "No answer found")
        confidence = answer_data.get("score", 0)
        return f"💡 **Answer:** {answer}\n📊 **Confidence:** {confidence:.3f}"
    else:
        return f"❌ Error: {result['error']}"

async def process_summarization(text: str) -> str:
    """Process text summarization."""
    if not text or len(text.split()) < 10:
        return "⚠️ Please provide at least 10 words to summarize."

    if not inference_engine.client:
        return "🔑 API token required for summarization."

    config = MODEL_REGISTRY[TaskType.SUMMARIZATION]
    result = await inference_engine._safe_inference_call(
        inference_engine.client.summarization,
        text.strip(),
        model=config.name,
        max_length=130,
        min_length=30,
        model_config=config
    )

    if result["success"]:
        summary_data = result["data"]
        if isinstance(summary_data, list) and summary_data:
            summary = summary_data[0].get("summary_text", "No summary generated")
        else:
            summary = str(summary_data)
        return f"📝 **Summary:** {summary}"
    else:
        return f"❌ Error: {result['error']}"

async def process_text_generation(prompt: str) -> str:
    """Process text generation."""
    if not prompt.strip():
        return "⚠️ Prompt cannot be empty."

    if not inference_engine.client:
        # Fallback for demo without API key
        return f"🤖 **Generated Text:** {prompt} [This is a demo response - configure API token for real generation]"

    config = MODEL_REGISTRY[TaskType.TEXT_GENERATION]
    result = await inference_engine._safe_inference_call(
        inference_engine.client.text_generation,
        prompt.strip(),
        model=config.name,
        max_new_tokens=100,
        temperature=0.7,
        model_config=config
    )

    if result["success"]:
        generated = result["data"]
        if isinstance(generated, list) and generated:
            text = generated[0].get("generated_text", prompt)
        else:
            text = str(generated)
        return f"✍️ **Generated Text:** {text}"
    else:
        return f"❌ Error: {result['error']}"

async def process_image_classification(image_path) -> str:
    """Process image classification."""
    if image_path is None:
        return "⚠️ Please upload an image."

    if not inference_engine.client:
        return "🔑 API token required for image classification."

    config = MODEL_REGISTRY[TaskType.IMAGE_CLASSIFICATION]
    result = await inference_engine._safe_inference_call(
        inference_engine.client.image_classification,
        image_path,
        model=config.name,
        model_config=config
    )

    if result["success"]:
        predictions = result["data"]
        if isinstance(predictions, list):
            formatted_results = []
            for i, pred in enumerate(predictions[:5], 1):
                label = pred.get("label", "Unknown")
                score = pred.get("score", 0)
                formatted_results.append(f"{i}. **{label}** ({score:.1%})")
            return "🖼️ **Image Classification:**\n" + "\n".join(formatted_results)
        else:
            return f"🖼️ **Result:** {predictions}"
    else:
        return f"❌ Error: {result['error']}"

async def process_feature_extraction(text: str) -> str:
    """Process feature extraction."""
    if not text.strip():
        return "⚠️ Input text cannot be empty."

    if not inference_engine.client:
        return "🔑 API token required for feature extraction."

    config = MODEL_REGISTRY[TaskType.FEATURE_EXTRACTION]
    result = await inference_engine._safe_inference_call(
        inference_engine.client.feature_extraction,
        text.strip(),
        model=config.name,
        model_config=config
    )

    if result["success"]:
        embeddings = result["data"]
        if isinstance(embeddings, list) and embeddings:
            dim = len(embeddings[0]) if embeddings[0] else 0
            sample = embeddings[0][:5] if dim >= 5 else embeddings[0]
            return f"🧮 **Feature Vector:** Dimension: {dim}\n**Sample values:** {sample}..."
        else:
            return f"🧮 **Embeddings:** {str(embeddings)[:200]}..."
    else:
        return f"❌ Error: {result['error']}"
# ──────────────────────────────────────────────────────────────────────────────
# 🎨 GRADIO INTERFACE - MODERN & RESPONSIVE
# ──────────────────────────────────────────────────────────────────────────────

def create_interface():
    """Create the main Gradio interface with modern design."""

    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="🚀 AI Multi-Task Hub",
        css="""
        .gradio-container {
            font-family: 'Inter', sans-serif;
        }
        .tab-nav button {
            font-weight: 500;
        }
        """
    ) as demo:

        gr.Markdown("""
        # 🚀 **AI Multi-Task Hub**
        ### *Professional-grade AI inference across multiple domains*

        **⚡ Real-time processing** | **🔒 Secure** | **🎯 Production-ready**
        """)

        with gr.Tabs():

            # 🎤 Speech Recognition Tab
            with gr.TabItem("🎤 Speech Recognition", id="asr"):
                with gr.Row():
                    with gr.Column(scale=1):
                        asr_input = gr.Audio(
                            sources=["upload", "microphone"],
                            type="numpy",
                            label="📁 Upload or 🎙️ Record Audio"
                        )
                        asr_button = gr.Button("🚀 Process Audio", variant="primary")

                    with gr.Column(scale=1):
                        asr_output = gr.Textbox(
                            label="📝 Transcription Result",
                            lines=4,
                            placeholder="Audio transcription will appear here..."
                        )

                # Wire up the processing
                asr_button.click(
                    fn=lambda audio: asyncio.run(process_speech_recognition(audio)),
                    inputs=[asr_input],
                    outputs=[asr_output]
                )

            # 💬 Chat Tab
            with gr.TabItem("💬 Chat", id="chat"):
                with gr.Row():
                    with gr.Column(scale=1):
                        chat_input = gr.Textbox(
                            label="💭 Messages (JSON Format)",
                            lines=5,
                            placeholder='[{"role":"user","content":"Hello, how are you?"}]',
                            value='[{"role":"user","content":"Hello! Tell me about artificial intelligence."}]'
                        )
                        chat_button = gr.Button("💬 Send Message", variant="primary")

                        gr.Markdown("""
                        **📋 Format Examples:**
                        - `[{"role":"user","content":"Your message here"}]`
                        - `[{"role":"system","content":"You are helpful"},{"role":"user","content":"Hi"}]`
                        """)

                    with gr.Column(scale=1):
                        chat_output = gr.Textbox(
                            label="🤖 AI Response",
                            lines=6,
                            placeholder="AI response will appear here..."
                        )

                chat_button.click(
                    fn=lambda msg: asyncio.run(process_chat(msg)),
                    inputs=[chat_input],
                    outputs=[chat_output]
                )

            # 🎭 Fill Mask Tab
            with gr.TabItem("🎭 Fill Mask", id="mask"):
                with gr.Row():
                    with gr.Column(scale=1):
                        mask_input = gr.Textbox(
                            label="🎯 Text with [MASK] Token",
                            lines=3,
                            placeholder="The capital of France is [MASK].",
                            value="The most popular programming language is [MASK]."
                        )
                        mask_button = gr.Button("🔍 Predict Mask", variant="primary")

                    with gr.Column(scale=1):
                        mask_output = gr.Textbox(
                            label="🎭 Predictions",
                            lines=6,
                            placeholder="Mask predictions will appear here..."
                        )

                mask_button.click(
                    fn=lambda text: asyncio.run(process_fill_mask(text)),
                    inputs=[mask_input],
                    outputs=[mask_output]
                )

            # ❓ Question Answering Tab
            with gr.TabItem("❓ Q&A", id="qa"):
                with gr.Row():
                    with gr.Column(scale=1):
                        qa_question = gr.Textbox(
                            label="❓ Question",
                            placeholder="What is machine learning?",
                            value="What is the main benefit of cloud computing?"
                        )
                        qa_context = gr.Textbox(
                            label="📖 Context",
                            lines=5,
                            placeholder="Provide context for the question...",
                            value="Cloud computing is a technology that allows users to access computing resources like servers, storage, and applications over the internet. It offers scalability, cost-effectiveness, and flexibility for businesses and individuals."
                        )
                        qa_button = gr.Button("🔍 Find Answer", variant="primary")

                    with gr.Column(scale=1):
                        qa_output = gr.Textbox(
                            label="💡 Answer",
                            lines=6,
                            placeholder="Answer will appear here..."
                        )

                qa_button.click(
                    fn=lambda q, c: asyncio.run(process_question_answering(q, c)),
                    inputs=[qa_question, qa_context],
                    outputs=[qa_output]
                )

            # 📝 Summarization Tab
            with gr.TabItem("📝 Summarization", id="summary"):
                with gr.Row():
                    with gr.Column(scale=1):
                        sum_input = gr.Textbox(
                            label="📄 Text to Summarize",
                            lines=8,
                            placeholder="Enter long text to summarize...",
                            value="Artificial intelligence (AI) is a rapidly evolving field that encompasses machine learning, deep learning, natural language processing, and computer vision. AI systems are being deployed across various industries including healthcare, finance, transportation, and entertainment. Machine learning algorithms can analyze vast amounts of data to identify patterns and make predictions. Deep learning, a subset of machine learning, uses neural networks with multiple layers to process complex data. Natural language processing enables computers to understand and generate human language. Computer vision allows machines to interpret and analyze visual information. The applications of AI are virtually limitless, from autonomous vehicles to medical diagnosis, from recommendation systems to fraud detection."
                        )
                        sum_button = gr.Button("📝 Summarize", variant="primary")

                    with gr.Column(scale=1):
                        sum_output = gr.Textbox(
                            label="✨ Summary",
                            lines=6,
                            placeholder="Summary will appear here..."
                        )

                sum_button.click(
                    fn=lambda text: asyncio.run(process_summarization(text)),
                    inputs=[sum_input],
                    outputs=[sum_output]
                )

            # ✍️ Text Generation Tab
            with gr.TabItem("✍️ Text Generation", id="generation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gen_input = gr.Textbox(
                            label="✨ Creative Prompt",
                            lines=4,
                            placeholder="Enter your creative prompt...",
                            value="The future of artificial intelligence will be"
                        )
                        gen_button = gr.Button("✍️ Generate Text", variant="primary")

                    with gr.Column(scale=1):
                        gen_output = gr.Textbox(
                            label="🎨 Generated Content",
                            lines=6,
                            placeholder="Generated text will appear here..."
                        )

                gen_button.click(
                    fn=lambda prompt: asyncio.run(process_text_generation(prompt)),
                    inputs=[gen_input],
                    outputs=[gen_output]
                )

            # 🖼️ Image Classification Tab
            with gr.TabItem("🖼️ Image Classification", id="image"):
                with gr.Row():
                    with gr.Column(scale=1):
                        img_input = gr.Image(
                            type="filepath",
                            label="🖼️ Upload Image"
                        )
                        img_button = gr.Button("🔍 Classify Image", variant="primary")

                    with gr.Column(scale=1):
                        img_output = gr.Textbox(
                            label="🏷️ Classification Results",
                            lines=6,
                            placeholder="Image classification results will appear here..."
                        )

                img_button.click(
                    fn=lambda img: asyncio.run(process_image_classification(img)),
                    inputs=[img_input],
                    outputs=[img_output]
                )

            # 🧮 Feature Extraction Tab
            with gr.TabItem("🧮 Feature Extraction", id="features"):
                with gr.Row():
                    with gr.Column(scale=1):
                        fe_input = gr.Textbox(
                            label="📝 Input Text",
                            lines=4,
                            placeholder="Enter text for feature extraction...",
                            value="Machine learning is transforming the world."
                        )
                        fe_button = gr.Button("🧮 Extract Features", variant="primary")

                    with gr.Column(scale=1):
                        fe_output = gr.Textbox(
                            label="🔢 Feature Vector",
                            lines=6,
                            placeholder="Feature vectors will appear here..."
                        )

                fe_button.click(
                    fn=lambda text: asyncio.run(process_feature_extraction(text)),
                    inputs=[fe_input],
                    outputs=[fe_output]
                )

        # Footer
        gr.Markdown("""
        ---
        **🔧 Configuration:** Set `HF_API_TOKEN` environment variable for full functionality
        **⚡ Performance:** Optimized for production workloads
        **🛡️ Security:** Enterprise-grade error handling and validation
        """)

    return demo
# ──────────────────────────────────────────────────────────────────────────────
# 🚀 APPLICATION ENTRY POINT
# ──────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    app = create_interface()

    # Production-ready launch configuration (only flags supported by Gradio 4.x)
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # Set to True for public sharing
        show_error=True,
        inbrowser=True
    )
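# Local run sketch (assumes a valid token; the value shown is a placeholder):
#   export HF_API_TOKEN=hf_xxxxxxxxxxxx
#   python app.py
# The UI is then served at http://localhost:7860; without HF_API_TOKEN the app
# still starts, but each handler returns the "API token required" message above.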