burtenshaw commited on
Commit
29272e4
·
1 Parent(s): 54cffe3

first commit

Browse files
Files changed (5) hide show
  1. .python-version +1 -0
  2. app.py +916 -0
  3. pyproject.toml +54 -0
  4. requirements.txt +182 -0
  5. uv.lock +0 -0
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
app.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AutoTrain Gradio MCP Server - All-in-One
3
+
4
+ This single Gradio app:
5
+ 1. Provides a web interface for managing AutoTrain jobs
6
+ 2. Automatically exposes MCP tools at /gradio_api/mcp/sse
7
+ 3. Handles all AutoTrain operations directly (no FastAPI needed)
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import time
13
+ import uuid
14
+ import threading
15
+ from datetime import datetime
16
+ from typing import List, Dict, Any
17
+ import socket
18
+
19
+ import gradio as gr
20
+ import pandas as pd
21
+ import wandb
22
+ from autotrain.project import AutoTrainProject
23
+ from autotrain.params import (
24
+ LLMTrainingParams,
25
+ TextClassificationParams,
26
+ ImageClassificationParams,
27
+ )
28
+
29
# Simple JSON-based storage (replace with SQLite if needed)
RUNS_FILE = "training_runs.json"  # cwd-relative path of the on-disk run registry
WANDB_PROJECT = os.environ.get("WANDB_PROJECT", "autotrain-mcp")  # W&B project used for every run
32
+
33
+
34
def load_runs() -> List[Dict[str, Any]]:
    """Return every persisted training run, or an empty list if none exist."""
    if not os.path.exists(RUNS_FILE):
        return []
    try:
        with open(RUNS_FILE, "r") as fh:
            return json.load(fh)
    except (json.JSONDecodeError, IOError):
        # A corrupt or unreadable store is treated as "no runs" rather
        # than crashing every caller.
        return []
43
+
44
+
45
def save_runs(runs: List[Dict[str, Any]]):
    """Overwrite the JSON store with the given list of training runs."""
    with open(RUNS_FILE, "w") as fh:
        json.dump(runs, fh, indent=2)
49
+
50
+
51
def get_status_emoji(status: str) -> str:
    """Map a run status (case-insensitive) to its display emoji; ❓ if unknown."""
    return {
        "pending": "⏳",
        "running": "🏃",
        "completed": "✅",
        "failed": "❌",
        "cancelled": "⏹️",
    }.get(status.lower(), "❓")
61
+
62
+
63
def create_autotrain_params(
    task: str,
    base_model: str,
    project_name: str,
    dataset_path: str,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    **kwargs,
):
    """Create AutoTrain parameter object based on task type.

    Args:
        task: One of "text-classification", "llm-sft", "llm-dpo",
            "llm-orpo", "llm-reward", or "image-classification".
        base_model: Hugging Face Hub model id to fine-tune.
        project_name: AutoTrain project name.
        dataset_path: Local path or Hub dataset name.
        epochs: Number of training epochs.
        batch_size: Per-device training batch size.
        learning_rate: Optimizer learning rate.
        **kwargs: Optional overrides (train_split, valid_split,
            text_column, target_column, max_seq_length, block_size,
            use_peft, quantization, image_column).

    Returns:
        A TextClassificationParams, LLMTrainingParams, or
        ImageClassificationParams instance ready for AutoTrainProject.

    Raises:
        ValueError: If `task` is not a supported task type.
        KeyError: For an unrecognized "llm-*" task (falls through the
            startswith check but misses trainer_map).
    """
    # Defaults shared by every task; "log": "wandb" routes metrics to W&B.
    common_params = {
        "model": base_model,
        "project_name": project_name,
        "data_path": dataset_path,
        "train_split": kwargs.get("train_split", "train"),
        "valid_split": kwargs.get("valid_split"),
        "epochs": epochs,
        "batch_size": batch_size,
        "lr": learning_rate,
        "log": "wandb",
        # Required defaults
        "warmup_ratio": 0.1,
        "gradient_accumulation": 1,
        "optimizer": "adamw_torch",
        "scheduler": "linear",
        "weight_decay": 0.01,
        "max_grad_norm": 1.0,
        "seed": 42,
        "logging_steps": 10,
        "auto_find_batch_size": False,
        "mixed_precision": "no",
        "save_total_limit": 1,
        "eval_strategy": "epoch",
    }

    if task == "text-classification":
        return TextClassificationParams(
            **common_params,
            text_column=kwargs.get("text_column", "text"),
            target_column=kwargs.get("target_column", "label"),
            max_seq_length=kwargs.get("max_seq_length", 128),
            early_stopping_patience=3,
            early_stopping_threshold=0.01,
        )

    elif task.startswith("llm-"):
        # Map the public task name onto AutoTrain's trainer identifier.
        trainer_map = {
            "llm-sft": "sft",
            "llm-dpo": "dpo",
            "llm-orpo": "orpo",
            "llm-reward": "reward",
        }

        return LLMTrainingParams(
            # NOTE(review): the filtered keys never occur in common_params,
            # so this filter is harmless but redundant.
            **{
                k: v
                for k, v in common_params.items()
                if k not in ["early_stopping_patience", "early_stopping_threshold"]
            },
            text_column=kwargs.get("text_column", "messages"),
            block_size=kwargs.get("block_size", 2048),
            peft=kwargs.get("use_peft", True),  # adapter training on by default
            quantization=kwargs.get("quantization", "int4"),
            trainer=trainer_map[task],
            chat_template="tokenizer",
            # LLM-specific defaults
            add_eos_token=True,
            model_max_length=2048,
            padding="right",
            use_flash_attention_2=False,
            disable_gradient_checkpointing=False,
            target_modules="all-linear",
            merge_adapter=False,
            lora_r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            model_ref=None,
            dpo_beta=0.1,
            max_prompt_length=512,
            max_completion_length=1024,
            prompt_text_column="prompt",
            rejected_text_column="rejected",
            unsloth=False,
            distributed_backend="accelerate",
        )

    elif task == "image-classification":
        return ImageClassificationParams(
            **common_params,
            image_column=kwargs.get("image_column", "image"),
            target_column=kwargs.get("target_column", "label"),
        )

    else:
        raise ValueError(f"Unsupported task type: {task}")
159
+
160
+
161
def run_training_background(run_id: str, params: Any, backend: str):
    """Run training job in background thread.

    Marks the run "running", starts a W&B run and an AutoTrainProject,
    then (demo only) sleeps 10s and marks the run "completed". On any
    exception the run is marked "failed" with the error message.

    Args:
        run_id: UUID of the run record in the JSON store.
        params: Task-specific AutoTrain params object.
        backend: AutoTrain backend name (e.g. "local", "spaces-...").

    NOTE(review): the shared JSON store is read-modified-written without
    a lock, so concurrent jobs may lose updates — confirm acceptable.
    """
    runs = load_runs()

    # Update status to running
    for run in runs:
        if run["run_id"] == run_id:
            run["status"] = "running"
            # NOTE(review): datetime.utcnow() is naive and deprecated on
            # Python 3.12+; consider datetime.now(timezone.utc) file-wide.
            run["started_at"] = datetime.utcnow().isoformat()
            break
    save_runs(runs)

    try:
        # Initialize W&B
        wandb_run = wandb.init(
            project=WANDB_PROJECT,
            name=f"{params.project_name}-{int(time.time())}",
            tags=["autotrain", "mcp"],
            config={
                "base_model": params.model,
                "dataset": params.data_path,
                "epochs": params.epochs,
                "batch_size": params.batch_size,
                "learning_rate": params.lr,
                "backend": backend,
            },
        )

        # Fall back to the project page if the run has no URL yet.
        wandb_url = (
            wandb_run.url if wandb_run.url else f"https://wandb.ai/{WANDB_PROJECT}"
        )

        # Update with W&B URL (reload first: another thread may have written)
        runs = load_runs()
        for run in runs:
            if run["run_id"] == run_id:
                run["wandb_url"] = wandb_url
                break
        save_runs(runs)

        # Create and start AutoTrain project
        project = AutoTrainProject(params=params, backend=backend, process=True)
        job_id = project.create()

        print(f"Training started for run {run_id} with job ID: {job_id}")

        # For demo purposes, simulate training completion after a short delay
        time.sleep(10)  # In real implementation, monitor actual training

        # Update status to completed
        runs = load_runs()
        for run in runs:
            if run["run_id"] == run_id:
                run["status"] = "completed"
                run["completed_at"] = datetime.utcnow().isoformat()
                break
        save_runs(runs)

        wandb.finish()

    except Exception as e:
        print(f"Training failed for run {run_id}: {str(e)}")

        # Update status to failed
        runs = load_runs()
        for run in runs:
            if run["run_id"] == run_id:
                run["status"] = "failed"
                run["error_message"] = str(e)
                run["completed_at"] = datetime.utcnow().isoformat()
                break
        save_runs(runs)

        # Only close W&B if init() actually succeeded before the failure.
        if wandb.run:
            wandb.finish()
236
+
237
+
238
+ # MCP Tool Functions (these automatically become MCP tools)
239
def start_training_job(
    task: str = "text-classification",
    project_name: str = "test-project",
    base_model: str = "distilbert-base-uncased",
    dataset_path: str = "imdb",
    epochs: str = "1",
    batch_size: str = "8",
    learning_rate: str = "2e-5",
    backend: str = "local",
) -> str:
    """
    Start a new AutoTrain training job.

    Args:
        task: Type of training task (text-classification, llm-sft,
            llm-dpo, llm-orpo, image-classification)
        project_name: Name for the training project
        base_model: Base model from Hugging Face Hub
            (e.g., distilbert-base-uncased)
        dataset_path: Dataset path or HF dataset name (e.g., imdb)
        epochs: Number of training epochs (default: 1)
        batch_size: Training batch size (default: 8)
        learning_rate: Learning rate for training (default: 2e-5)
        backend: Training backend to use (default: local)

    Returns:
        Status message with run ID and details; on any failure an
        error string is returned instead (this function never raises).
    """
    try:
        # Convert string parameters; a bad number is reported below.
        epochs_int = int(epochs)
        batch_size_int = int(batch_size)
        learning_rate_float = float(learning_rate)

        # Build AutoTrain parameters BEFORE persisting the run record so
        # an invalid task/config can never leave an orphaned "pending"
        # entry in the store (previously the record was saved first).
        params = create_autotrain_params(
            task=task,
            base_model=base_model,
            project_name=project_name,
            dataset_path=dataset_path,
            epochs=epochs_int,
            batch_size=batch_size_int,
            learning_rate=learning_rate_float,
        )

        # Generate run ID
        run_id = str(uuid.uuid4())

        # Create run record.
        # NOTE(review): datetime.utcnow() is naive and deprecated on
        # Python 3.12+; kept for consistency with the rest of the file.
        now = datetime.utcnow().isoformat()
        run_data = {
            "run_id": run_id,
            "project_name": project_name,
            "task": task,
            "base_model": base_model,
            "dataset_path": dataset_path,
            "status": "pending",
            "created_at": now,
            "updated_at": now,
            "config": {
                "task": task,
                "epochs": epochs_int,
                "batch_size": batch_size_int,
                "learning_rate": learning_rate_float,
                "backend": backend,
            },
        }

        # Save to storage
        runs = load_runs()
        runs.append(run_data)
        save_runs(runs)

        # Start training in a daemon thread so it never blocks shutdown.
        thread = threading.Thread(
            target=run_training_background,
            args=(run_id, params, backend),
            daemon=True,
        )
        thread.start()

        return f"""✅ Training job submitted successfully!

Run ID: {run_id}
Project: {project_name}
Task: {task}
Model: {base_model}
Dataset: {dataset_path}

Configuration:
• Epochs: {epochs}
• Batch Size: {batch_size}
• Learning Rate: {learning_rate}
• Backend: {backend}

🔗 Monitor progress:
• Gradio UI: http://localhost:7860
• W&B tracking will be available once training starts

💡 Use get_training_runs() to check status"""

    except Exception as e:
        return f"❌ Error submitting job: {str(e)}"
340
+
341
+
342
def get_training_runs(limit: str = "20", status: str = "") -> str:
    """
    Get list of training runs with their status and details.

    Args:
        limit: Maximum number of runs to return (default: 20)
        status: Filter by run status (pending, running, completed,
            failed, cancelled)

    Returns:
        Formatted list of training runs with status and links
    """
    try:
        selected = load_runs()

        # Optional status filter.
        if status:
            selected = [r for r in selected if r.get("status") == status]

        # Keep only the most recent `limit` entries.
        selected = selected[-int(limit) :]

        if not selected:
            return "No training runs found. Start a new training job to see it here!"

        chunks = [f"📊 Training Runs (showing {len(selected)}):\n\n"]

        for entry in reversed(selected):  # Show newest first
            emoji = get_status_emoji(entry["status"])

            chunks.append(
                f"{emoji} **{entry['project_name']}** ({entry['run_id'][:8]}...)\n"
            )
            chunks.append(f" Task: {entry['task']}\n")
            chunks.append(f" Model: {entry['base_model']}\n")
            chunks.append(f" Status: {entry['status'].title()}\n")
            chunks.append(f" Created: {entry['created_at']}\n")

            if entry.get("wandb_url"):
                chunks.append(f" 🔗 W&B: {entry['wandb_url']}\n")

            if entry.get("error_message"):
                chunks.append(f" ❌ Error: {entry['error_message']}\n")

            chunks.append("\n")

        return "".join(chunks)

    except Exception as e:
        return f"❌ Error fetching runs: {str(e)}"
394
+
395
+
396
def get_run_details(run_id: str) -> str:
    """
    Get detailed information about a specific training run.

    Args:
        run_id: ID of the training run (can be partial ID)

    Returns:
        Detailed run information including config and status
    """
    try:
        all_runs = load_runs()

        # Match on the full ID or a prefix; first match wins.
        run = next(
            (
                r
                for r in all_runs
                if r["run_id"] == run_id or r["run_id"].startswith(run_id)
            ),
            None,
        )

        if run is None:
            return f"❌ Training run {run_id} not found"

        details_text = f"""📋 Training Run Details

**Run ID:** {run["run_id"]}
**Project:** {run["project_name"]}
**Task:** {run["task"]}
**Model:** {run["base_model"]}
**Dataset:** {run["dataset_path"]}
**Status:** {run["status"].title()}

**Timestamps:**
• Created: {run["created_at"]}
• Updated: {run.get("updated_at", "N/A")}"""

        # Optional timestamps only appear once the lifecycle reaches them.
        if run.get("started_at"):
            details_text += f"\n• Started: {run['started_at']}"
        if run.get("completed_at"):
            details_text += f"\n• Completed: {run['completed_at']}"

        if run.get("wandb_url"):
            details_text += f"\n\n🔗 **W&B Dashboard:** {run['wandb_url']}"

        if run.get("error_message"):
            details_text += f"\n\n❌ **Error:** {run['error_message']}"

        if run.get("config"):
            cfg = run["config"]
            details_text += "\n\n⚙️ **Training Configuration:**"
            details_text += f"\n• Epochs: {cfg.get('epochs')}"
            details_text += f"\n• Batch Size: {cfg.get('batch_size')}"
            details_text += f"\n• Learning Rate: {cfg.get('learning_rate')}"
            details_text += f"\n• Backend: {cfg.get('backend')}"

        return details_text

    except Exception as e:
        return f"❌ Error fetching run details: {str(e)}"
456
+
457
+
458
def get_task_recommendations(
    task: str = "text-classification", dataset_size: str = "medium"
) -> str:
    """
    Get training recommendations for a specific task type.

    Args:
        task: Task type (text-classification, llm-sft, image-classification)
        dataset_size: Size of dataset (small, medium, large)

    Returns:
        Recommended models, parameters, and best practices
    """
    recommendations = {
        "text-classification": {
            "models": ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"],
            "params": {"batch_size": 16, "learning_rate": 2e-5, "epochs": 3},
            "backends": ["local", "spaces-t4-small"],
            "notes": [
                "Good for sentiment analysis",
                "Works well with IMDB, AG News datasets",
            ],
        },
        "llm-sft": {
            "models": [
                "microsoft/DialoGPT-medium",
                "HuggingFaceTB/SmolLM2-1.7B-Instruct",
            ],
            "params": {"batch_size": 1, "learning_rate": 1e-5, "epochs": 3},
            "backends": ["spaces-t4-medium", "spaces-a10g-large"],
            "notes": ["Use PEFT for efficiency", "Ensure proper chat formatting"],
        },
        "image-classification": {
            "models": ["google/vit-base-patch16-224", "microsoft/resnet-50"],
            "params": {"batch_size": 32, "learning_rate": 2e-5, "epochs": 5},
            "backends": ["local", "spaces-t4-small"],
            "notes": ["Ensure images are preprocessed", "Works with CIFAR, ImageNet"],
        },
    }

    # Unknown tasks fall back to an explicit "nothing to recommend" entry.
    rec = recommendations.get(
        task,
        {
            "models": [],
            "params": {},
            "backends": ["local"],
            "notes": ["No specific recommendations available"],
        },
    )

    # Pre-render each bullet list, then assemble the final report.
    model_lines = "\n".join(f"• {m}" for m in rec["models"])
    param_lines = "\n".join(f"• {k}: {v}" for k, v in rec["params"].items())
    backend_lines = "\n".join(f"• {b}" for b in rec["backends"])
    note_lines = "\n".join(f"• {n}" for n in rec["notes"])

    return (
        f"🎯 Training Recommendations for {task.title()} ({dataset_size} dataset)\n"
        "\n"
        "**Recommended Models:**\n"
        f"{model_lines}\n"
        "\n"
        "**Recommended Parameters:**\n"
        f"{param_lines}\n"
        "\n"
        "**Backend Suggestions:**\n"
        f"{backend_lines}\n"
        "\n"
        "**Best Practices:**\n"
        f"{note_lines}"
    )
524
+
525
+
526
def get_system_status(random_string: str = "") -> str:
    """
    Get AutoTrain system status and capabilities.

    Args:
        random_string: Unused placeholder argument (kept so the MCP tool
            schema exposes a parameter); any value is ignored.

    Returns:
        System status, available tasks, backends, and statistics
    """
    try:
        runs = load_runs()

        # Calculate stats from the persisted run records.
        total_runs = len(runs)
        running_runs = len([r for r in runs if r.get("status") == "running"])
        completed_runs = len([r for r in runs if r.get("status") == "completed"])
        failed_runs = len([r for r in runs if r.get("status") == "failed"])

        # Hard-coded capability lists; keep in sync with the UI dropdowns.
        available_tasks = [
            "text-classification",
            "llm-sft",
            "llm-dpo",
            "llm-orpo",
            "image-classification",
        ]

        available_backends = [
            "local",
            "spaces-t4-small",
            "spaces-t4-medium",
            "spaces-a10g-large",
            "spaces-a10g-small",
            "spaces-a100-large",
            "spaces-l4x1",
            "spaces-l4x4",
        ]

        # NOTE(review): the "... and N more" clause below can never fire
        # while available_backends holds <= 10 entries; it yields an
        # empty line in the rendered text.
        status_text = f"""🚀 AutoTrain Gradio MCP Server - System Status

**Server Status:** Running
**Total Runs:** {total_runs}
**Active Runs:** {running_runs}
**Completed Runs:** {completed_runs}
**Failed Runs:** {failed_runs}

**Available Tasks:** {len(available_tasks)}
{chr(10).join(f" • {task}" for task in available_tasks)}

**Available Backends:** {len(available_backends)}
{chr(10).join(f" • {backend}" for backend in available_backends[:10])}
{
    f" ... and {len(available_backends) - 10} more"
    if len(available_backends) > 10
    else ""
}

💡 **Access Points:**
• Gradio UI: http://localhost:7860
• MCP Server: http://localhost:7860/gradio_api/mcp/sse
• MCP Schema: http://localhost:7860/gradio_api/mcp/schema

🛠️ **W&B Integration:**
• Project: {WANDB_PROJECT}
• Set WANDB_PROJECT environment variable to customize"""

        return status_text

    except Exception as e:
        return f"❌ Error getting system status: {str(e)}"
593
+
594
+
595
def refresh_data(random_string: str = "") -> str:
    """MCP stub: acknowledge a refresh request (argument is ignored)."""
    message = "Data refreshed successfully"
    return message
598
+
599
+
600
def load_initial_data(random_string: str = "") -> str:
    """MCP stub: acknowledge an initial-load request (argument is ignored)."""
    message = "Initial data loaded successfully"
    return message
603
+
604
+
605
+ # Web UI Functions
606
def fetch_runs_for_ui():
    """Build the dashboard table of all runs, newest first, as a DataFrame."""
    try:
        runs = load_runs()

        column_names = [
            "Status",
            "Project",
            "Task",
            "Model",
            "Created",
            "W&B Link",
            "Run ID",
        ]

        if not runs:
            # Empty frame that still carries the expected column layout.
            return pd.DataFrame({name: [] for name in column_names})

        rows = []
        for entry in reversed(runs):  # Newest first
            link = ""
            if entry.get("wandb_url"):
                link = (
                    f'<a href="{entry["wandb_url"]}" target="_blank">View W&B</a>'
                )

            rows.append(
                {
                    "Status": f"{get_status_emoji(entry['status'])} {entry['status'].title()}",
                    "Project": entry["project_name"],
                    "Task": entry["task"].replace("-", " ").title(),
                    "Model": entry["base_model"],
                    "Created": entry["created_at"][:16].replace("T", " "),
                    "W&B Link": link,
                    "Run ID": entry["run_id"][:8] + "...",
                }
            )

        return pd.DataFrame(rows)

    except Exception as e:
        return pd.DataFrame({"Error": [f"Failed to fetch runs: {str(e)}"]})
648
+
649
+
650
def submit_training_job_ui(
    task,
    project_name,
    base_model,
    dataset_path,
    epochs,
    batch_size,
    learning_rate,
    backend,
):
    """Validate the form, submit the job, and return (message, runs table)."""
    # Reject the submission if any required field is empty/falsy.
    required_fields = [task, project_name, base_model, dataset_path]
    if not all(required_fields):
        return "❌ Please fill in all required fields", fetch_runs_for_ui()

    # The MCP tool expects string-typed numeric parameters.
    message = start_training_job(
        task=task,
        project_name=project_name,
        base_model=base_model,
        dataset_path=dataset_path,
        epochs=str(epochs),
        batch_size=str(batch_size),
        learning_rate=str(learning_rate),
        backend=backend,
    )

    return message, fetch_runs_for_ui()
676
+
677
+
678
+ # Create Gradio Interface
679
# Top-level Gradio app: the Blocks context is both the web UI and (via
# mcp_server=True at launch) the MCP tool surface.
with gr.Blocks(
    title="AutoTrain Gradio MCP Server",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px !important;
    }
    """,
) as app:
    gr.Markdown("""
    # 🚀 AutoTrain Gradio MCP Server

    **All-in-One Solution:** Web UI + MCP Server + AutoTrain Integration

    • **Web Interface**: Manage training jobs through this UI
    • **MCP Server**: AI assistants can use tools at `http://localhost:7860/gradio_api/mcp/sse`
    • **Direct Integration**: No FastAPI needed - everything runs in Gradio
    """)

    with gr.Tabs():
        # Dashboard Tab — runs table plus a live stats panel.
        with gr.Tab("📊 Dashboard"):
            with gr.Row():
                with gr.Column(scale=3):
                    gr.Markdown("## Training Runs")
                    refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
                    runs_table = gr.Dataframe(
                        value=fetch_runs_for_ui(), interactive=False
                    )

                with gr.Column(scale=1):
                    gr.Markdown("## Quick Stats")
                    stats = gr.Textbox(
                        value=get_system_status(), interactive=False, lines=15
                    )

        # Start Training Tab — form mirroring start_training_job()'s args.
        with gr.Tab("🏃 Start Training"):
            gr.Markdown("## Submit New Training Job")

            with gr.Row():
                with gr.Column():
                    task_dropdown = gr.Dropdown(
                        choices=[
                            "text-classification",
                            "llm-sft",
                            "llm-dpo",
                            "llm-orpo",
                            "image-classification",
                        ],
                        label="Task Type",
                        value="text-classification",
                    )

                    project_name = gr.Textbox(
                        label="Project Name", placeholder="my-training-project"
                    )

                    base_model = gr.Textbox(
                        label="Base Model", placeholder="distilbert-base-uncased"
                    )

                    dataset_path = gr.Textbox(label="Dataset Path", placeholder="imdb")

                with gr.Column():
                    epochs = gr.Slider(1, 20, value=3, step=1, label="Epochs")
                    batch_size = gr.Slider(1, 128, value=16, step=1, label="Batch Size")
                    learning_rate = gr.Number(value=2e-5, label="Learning Rate")
                    backend = gr.Dropdown(
                        choices=["local", "spaces-t4-small", "spaces-a10g-large"],
                        label="Backend",
                        value="local",
                    )

            submit_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")
            submit_output = gr.Textbox(label="Status", interactive=False, lines=10)

        # MCP Info Tab — static docs; {{ }} escapes literal braces in the
        # f-string JSON sample.
        with gr.Tab("🔗 MCP Integration"):
            gr.Markdown(f"""
    ## MCP Server Information

    This Gradio app automatically serves as an MCP server.

    **MCP Endpoint:** `http://localhost:7860/gradio_api/mcp/sse`
    **MCP Schema:** `http://localhost:7860/gradio_api/mcp/schema`

    ### Available MCP Tools:

    - `start_training_job` - Submit new training jobs
    - `get_training_runs` - List all runs with status
    - `get_run_details` - Get detailed run information
    - `delete_training_run` - Delete training runs
    - `get_task_recommendations` - Get training recommendations
    - `get_system_status` - Check system status

    ### Claude Desktop Configuration:

    ```json
    {{
        "mcpServers": {{
            "autotrain": {{
                "url": "http://localhost:7860/gradio_api/mcp/sse"
            }}
        }}
    }}
    ```

    ### Current Stats:

    Total Runs: {len(load_runs())}
    W&B Project: {WANDB_PROJECT}
    """)

        # MCP Tools Tab — one gr.Interface per tool for manual testing.
        with gr.Tab("🔧 MCP Tools"):
            gr.Markdown("## MCP Tool Testing Interface")
            gr.Markdown("These tools are exposed via MCP for Claude Desktop")

            gr.Interface(
                fn=get_system_status,
                inputs=[],
                outputs=gr.Textbox(label="System Status"),
                title="get_system_status",
                description="Get AutoTrain system status and capabilities",
            )

            gr.Interface(
                fn=get_training_runs,
                inputs=[
                    gr.Textbox(label="limit", value="20"),
                    gr.Textbox(label="status", value=""),
                ],
                outputs=gr.Textbox(label="Training Runs"),
                title="get_training_runs",
                description="Get list of training runs with status",
            )

            gr.Interface(
                fn=start_training_job,
                inputs=[
                    gr.Textbox(label="task", value="text-classification"),
                    gr.Textbox(label="project_name", value="test-project"),
                    gr.Textbox(label="base_model", value="distilbert-base-uncased"),
                    gr.Textbox(label="dataset_path", value="imdb"),
                    gr.Textbox(label="epochs", value="1"),
                    gr.Textbox(label="batch_size", value="8"),
                    gr.Textbox(label="learning_rate", value="2e-5"),
                    gr.Textbox(label="backend", value="local"),
                ],
                outputs=gr.Textbox(label="Training Job Result"),
                title="start_training_job",
                description="Start a new AutoTrain training job",
            )

            gr.Interface(
                fn=get_run_details,
                inputs=gr.Textbox(
                    label="run_id", placeholder="Enter run ID or first 8 chars"
                ),
                outputs=gr.Textbox(label="Run Details"),
                title="get_run_details",
                description="Get detailed information about a training run",
            )

            gr.Interface(
                fn=get_task_recommendations,
                inputs=[
                    gr.Textbox(label="task", value="text-classification"),
                    gr.Textbox(label="dataset_size", value="medium"),
                ],
                outputs=gr.Textbox(label="Recommendations"),
                title="get_task_recommendations",
                description="Get training recommendations for a task",
            )

    # Event handlers with proper function names (not lambda)
    # NOTE(review): these nested defs shadow the module-level MCP stubs
    # refresh_data()/load_initial_data() (which return plain strings);
    # inside this Blocks context the tuple-returning versions below are
    # the ones wired to the events — confirm the shadowing is intended.
    def refresh_data():
        return fetch_runs_for_ui(), get_system_status()

    def load_initial_data():
        return fetch_runs_for_ui(), get_system_status()

    # Manual refresh re-renders both the table and the stats panel.
    refresh_btn.click(
        fn=refresh_data,
        outputs=[runs_table, stats],
    )

    # Form submit: message goes to the status box, table is refreshed.
    submit_btn.click(
        fn=submit_training_job_ui,
        inputs=[
            task_dropdown,
            project_name,
            base_model,
            dataset_path,
            epochs,
            batch_size,
            learning_rate,
            backend,
        ],
        outputs=[submit_output, runs_table],
    )

    # Load initial data when the page first renders.
    app.load(
        fn=load_initial_data,
        outputs=[runs_table, stats],
    )
887
+
888
+
889
+ # Helper to find an available port
890
+ def _find_available_port(start_port: int = 7860, max_tries: int = 20) -> int:
891
+ """Return the first available port starting from `start_port`."""
892
+ port = start_port
893
+ for _ in range(max_tries):
894
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
895
+ try:
896
+ s.bind(("0.0.0.0", port))
897
+ return port # Port is free
898
+ except OSError:
899
+ port += 1 # Try next port
900
+ # If no port found, let OS pick one
901
+ return 0
902
+
903
+
904
+ if __name__ == "__main__":
905
+ chosen_port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
906
+ try:
907
+ chosen_port = _find_available_port(chosen_port)
908
+ except Exception:
909
+ # Fallback to OS-assigned port if something goes wrong
910
+ chosen_port = 0
911
+
912
+ app.launch(
913
+ server_name="0.0.0.0",
914
+ server_port=chosen_port,
915
+ mcp_server=True, # Enable MCP server functionality
916
+ )
pyproject.toml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "autotrain-gradio-mcp"
3
+ version = "0.1.0"
4
+ description = "AutoTrain Gradio MCP Server - All-in-One Solution"
5
+ authors = [
6
+ {name = "AutoTrain Team", email = "autotrain@example.com"}
7
+ ]
8
+ readme = "README.md"
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ # Core dependencies
12
+ "gradio[mcp]>=5.0.0",
13
+ "autotrain-advanced>=0.8.0",
14
+ "pandas>=2.0.0",
15
+ "wandb>=0.16.0",
16
+
17
+ # MCP and async support
18
+ "httpx>=0.25.0",
19
+ "aiofiles>=23.0.0",
20
+
21
+ # Data handling
22
+ "datasets>=2.0.0",
23
+ "torch>=2.0.0",
24
+ "transformers>=4.30.0",
25
+
26
+ # Optional ML frameworks
27
+ "accelerate>=0.20.0",
28
+ "peft>=0.4.0",
29
+ "bitsandbytes>=0.41.0",
30
+ ]
31
+
32
+ [project.optional-dependencies]
33
+ dev = [
34
+ "pytest>=7.0.0",
35
+ "black>=23.0.0",
36
+ "flake8>=6.0.0",
37
+ "mypy>=1.0.0",
38
+ ]
39
+
40
+ [build-system]
41
+ requires = ["setuptools>=65.0", "wheel"]
42
+ build-backend = "setuptools.build_meta"
43
+
44
+ [project.scripts]
45
+ autotrain-gradio = "autotrain_gradio:main"
46
+
47
+ [tool.black]
48
+ line-length = 88
49
+ target-version = ['py310']
50
+
51
+ [tool.mypy]
52
+ python_version = "3.10"
53
+ warn_return_any = true
54
+ warn_unused_configs = true
requirements.txt ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv export --format requirements-txt --no-hashes
3
+ -e .
4
+ absl-py==2.3.0
5
+ accelerate==1.2.1
6
+ aiofiles==23.2.1
7
+ aiohappyeyeballs==2.6.1
8
+ aiohttp==3.12.9
9
+ aiosignal==1.3.2
10
+ albucore==0.0.21
11
+ albumentations==1.4.23
12
+ alembic==1.16.1
13
+ annotated-types==0.7.0
14
+ anyio==4.9.0
15
+ async-timeout==5.0.1 ; python_full_version < '3.11'
16
+ attrs==25.3.0
17
+ audioop-lts==0.2.1 ; python_full_version >= '3.13'
18
+ authlib==1.4.0
19
+ bitsandbytes==0.45.0
20
+ brotli==1.1.0 ; platform_python_implementation == 'CPython'
21
+ brotlicffi==1.1.0.0 ; platform_python_implementation == 'PyPy'
22
+ cachetools==6.0.0
23
+ certifi==2025.4.26
24
+ cffi==1.17.1
25
+ charset-normalizer==3.4.2
26
+ click==8.2.1
27
+ colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows'
28
+ colorlog==6.9.0
29
+ contourpy==1.3.2
30
+ cryptography==44.0.0
31
+ cycler==0.12.1
32
+ datasets==3.2.0
33
+ dill==0.3.8
34
+ einops==0.8.0
35
+ eval-type-backport==0.2.2
36
+ evaluate==0.4.3
37
+ exceptiongroup==1.3.0 ; python_full_version < '3.11'
38
+ fastapi==0.115.6
39
+ ffmpy==0.6.0
40
+ filelock==3.18.0
41
+ fonttools==4.58.1
42
+ frozenlist==1.6.2
43
+ fsspec==2024.9.0
44
+ gitdb==4.0.12
45
+ gitpython==3.1.44
46
+ gradio>=5.33.0
47
+ gradio-client==1.7.0
48
+ greenlet==3.2.3 ; (python_full_version < '3.14' and platform_machine == 'AMD64') or (python_full_version < '3.14' and platform_machine == 'WIN32') or (python_full_version < '3.14' and platform_machine == 'aarch64') or (python_full_version < '3.14' and platform_machine == 'amd64') or (python_full_version < '3.14' and platform_machine == 'ppc64le') or (python_full_version < '3.14' and platform_machine == 'win32') or (python_full_version < '3.14' and platform_machine == 'x86_64')
49
+ grpcio==1.72.1
50
+ h11==0.16.0
51
+ hf-transfer==0.1.9
52
+ httpcore==1.0.9
53
+ httpx==0.28.1
54
+ huggingface-hub==0.27.0
55
+ idna==3.10
56
+ inflate64==1.0.3
57
+ ipadic==1.0.0
58
+ itsdangerous==2.2.0
59
+ jinja2==3.1.6
60
+ jiwer==3.0.5
61
+ joblib==1.4.2
62
+ kiwisolver==1.4.8
63
+ lightning-utilities==0.14.3
64
+ loguru==0.7.3
65
+ mako==1.3.10
66
+ markdown==3.8
67
+ markdown-it-py==3.0.0
68
+ markupsafe==2.1.5
69
+ matplotlib==3.10.3
70
+ mdurl==0.1.2
71
+ mpmath==1.3.0
72
+ multidict==6.4.4
73
+ multiprocess==0.70.16
74
+ multivolumefile==0.2.3
75
+ networkx==3.4.2 ; python_full_version < '3.11'
76
+ networkx==3.5 ; python_full_version >= '3.11'
77
+ nltk==3.9.1
78
+ numpy==2.2.6
79
+ nvidia-cublas-cu12==12.6.4.1 ; platform_machine == 'x86_64' and platform_system == 'Linux'
80
+ nvidia-cuda-cupti-cu12==12.6.80 ; platform_machine == 'x86_64' and platform_system == 'Linux'
81
+ nvidia-cuda-nvrtc-cu12==12.6.77 ; platform_machine == 'x86_64' and platform_system == 'Linux'
82
+ nvidia-cuda-runtime-cu12==12.6.77 ; platform_machine == 'x86_64' and platform_system == 'Linux'
83
+ nvidia-cudnn-cu12==9.5.1.17 ; platform_machine == 'x86_64' and platform_system == 'Linux'
84
+ nvidia-cufft-cu12==11.3.0.4 ; platform_machine == 'x86_64' and platform_system == 'Linux'
85
+ nvidia-cufile-cu12==1.11.1.6 ; platform_machine == 'x86_64' and platform_system == 'Linux'
86
+ nvidia-curand-cu12==10.3.7.77 ; platform_machine == 'x86_64' and platform_system == 'Linux'
87
+ nvidia-cusolver-cu12==11.7.1.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
88
+ nvidia-cusparse-cu12==12.5.4.2 ; platform_machine == 'x86_64' and platform_system == 'Linux'
89
+ nvidia-cusparselt-cu12==0.6.3 ; platform_machine == 'x86_64' and platform_system == 'Linux'
90
+ nvidia-ml-py==12.535.161
91
+ nvidia-nccl-cu12==2.26.2 ; platform_machine != 'aarch64' and platform_system == 'Linux'
92
+ nvidia-nvjitlink-cu12==12.6.85 ; platform_machine == 'x86_64' and platform_system == 'Linux'
93
+ nvidia-nvtx-cu12==12.6.77 ; platform_machine == 'x86_64' and platform_system == 'Linux'
94
+ nvitop==1.3.2
95
+ opencv-python-headless==4.11.0.86
96
+ optuna==4.1.0
97
+ orjson==3.10.18
98
+ packaging==24.2
99
+ pandas==2.2.3
100
+ peft==0.14.0
101
+ pillow==11.0.0
102
+ platformdirs==4.3.8
103
+ propcache==0.3.1
104
+ protobuf==6.31.1
105
+ psutil==7.0.0
106
+ py7zr==0.22.0
107
+ pyarrow==20.0.0
108
+ pybcj==1.0.6
109
+ pycocotools==2.0.8
110
+ pycparser==2.22
111
+ pycryptodomex==3.23.0
112
+ pydantic==2.10.4
113
+ pydantic-core==2.27.2
114
+ pydub==0.25.1
115
+ pygments==2.19.1
116
+ pyngrok==7.2.1
117
+ pyparsing==3.2.3
118
+ pyppmd==1.1.1
119
+ python-dateutil==2.9.0.post0
120
+ python-multipart==0.0.20
121
+ pytz==2025.2
122
+ pyyaml==6.0.2
123
+ pyzstd==0.17.0
124
+ rapidfuzz==3.13.0
125
+ regex==2024.11.6
126
+ requests==2.32.3
127
+ rich==14.0.0
128
+ rouge-score==0.1.2
129
+ ruff==0.11.13 ; sys_platform != 'emscripten'
130
+ sacremoses==0.1.1
131
+ safehttpx==0.1.6
132
+ safetensors==0.5.3
133
+ scikit-learn==1.6.0
134
+ scipy==1.15.3
135
+ semantic-version==2.10.0
136
+ sentence-transformers==3.3.1
137
+ sentencepiece==0.2.0
138
+ sentry-sdk==2.29.1
139
+ seqeval==1.2.2
140
+ setproctitle==1.3.6
141
+ setuptools==80.9.0
142
+ shellingham==1.5.4 ; sys_platform != 'emscripten'
143
+ simsimd==6.4.7
144
+ six==1.17.0
145
+ smmap==5.0.2
146
+ sniffio==1.3.1
147
+ sqlalchemy==2.0.41
148
+ starlette==0.41.3
149
+ stringzilla==3.12.5
150
+ sympy==1.14.0
151
+ tensorboard==2.18.0
152
+ tensorboard-data-server==0.7.2
153
+ termcolor==3.1.0
154
+ texttable==1.7.0
155
+ threadpoolctl==3.6.0
156
+ tiktoken==0.8.0
157
+ timm==1.0.12
158
+ tokenizers==0.21.1
159
+ tomli==2.2.1 ; python_full_version < '3.11'
160
+ tomlkit==0.13.3
161
+ torch==2.7.1
162
+ torchmetrics==1.6.0
163
+ torchvision==0.22.1
164
+ tqdm==4.67.1
165
+ transformers==4.48.0
166
+ triton==3.3.1 ; platform_machine == 'x86_64' and platform_system == 'Linux'
167
+ trl==0.13.0
168
+ typer==0.16.0 ; sys_platform != 'emscripten'
169
+ typing-extensions==4.14.0
170
+ tzdata==2025.2
171
+ urllib3==2.4.0
172
+ uvicorn==0.34.0
173
+ wandb==0.20.1
174
+ websockets==14.2
175
+ werkzeug==3.1.3
176
+ win32-setctime==1.2.0 ; sys_platform == 'win32'
177
+ windows-curses==2.4.1 ; platform_system == 'Windows'
178
+ xgboost==2.1.3
179
+ xxhash==3.5.0
180
+ yarl==1.20.0
181
+ git+https://github.com/huggingface/autotrain-advanced.git
182
+
uv.lock ADDED
The diff for this file is too large to render. See raw diff