AryanRathod3097 committed
Commit d266dc0 · verified · 1 Parent(s): 14f8553

Update app.py

Files changed (1)
app.py +85 -84
app.py CHANGED
@@ -1,95 +1,96 @@
- """
- Tiny-CodeNyx – 160 MB distilled general-knowledge code model
- Fine-tuned on 5k Q&A snippets in < 2 min
- """
- import os, json, torch, gradio as gr
- from datasets import load_dataset
- from transformers import (AutoTokenizer, AutoModelForCausalLM,
-                           Trainer, TrainingArguments, DataCollatorForLanguageModeling)
- from peft import LoraConfig, get_peft_model
-
- MODEL_ID = "distilgpt2"
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
- tokenizer.pad_token = tokenizer.eos_token
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
-
- # ---------- 1. 5k-shot general-knowledge dataset ----------
- def build_mini_dataset():
-     """Return a tiny JSON that mixes code & general facts."""
-     data = [
-         {"text": "Q: Write a FastAPI route that returns current UTC time.\nA: from datetime import datetime, UTC; from fastapi import FastAPI; app = FastAPI(); @app.get('/time'); def get_time(): return {'utc': datetime.now(UTC).isoformat()}"},
-         {"text": "Q: Capital of France?\nA: Paris"},
-         {"text": "Q: Print Fibonacci sequence in Python.\nA: a,b=0,1;[print(a)or(a:=b,b:=a+b)for _ in range(10)]"},
-         {"text": "Q: What is 2+2?\nA: 4"},
-         {"text": "Q: Explain list comprehension.\nA: [expr for item in iterable if condition]"},
-         {"text": "Q: Who wrote Romeo and Juliet?\nA: William Shakespeare"},
-         {"text": "Q: How to reverse a string in Python?\nA: s[::-1]"},
-         {"text": "Q: Largest planet?\nA: Jupiter"},
-         {"text": "Q: SQL to create users table.\nA: CREATE TABLE users(id INT PRIMARY KEY, name VARCHAR(100));"},
-         {"text": "Q: Speed of light in vacuum?\nA: 299 792 458 m/s"},
-     ]
-     # replicate to 5 000 lines
-     data = data * 500
-     with open("mini.json", "w") as f:
-         for d in data:
-             f.write(json.dumps(d) + "\n")
-     return load_dataset("json", data_files="mini.json")["train"]
-
- dataset = build_mini_dataset()
-
- # ---------- 2. Tokenize ----------
- def tokenize(examples):
-     return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
-
- dataset = dataset.map(tokenize, batched=True)
- data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
-
- # ---------- 3. LoRA fine-tune ----------
- lora_config = LoraConfig(
-     r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["c_attn"]
- )
- model = get_peft_model(model, lora_config)
-
- training_args = TrainingArguments(
-     output_dir="./tiny-codenyx",
-     per_device_train_batch_size=4,
-     num_train_epochs=1,
-     logging_steps=50,
-     fp16=True,
-     save_steps=500,
-     save_total_limit=1,
-     report_to=None,
- )
- trainer = Trainer(
-     model=model,
-     args=training_args,
-     train_dataset=dataset,
-     data_collator=data_collator,
- )
- trainer.train()
- trainer.save_model("./tiny-codenyx")
-
- # ---------- 4. Gradio chat ----------
- model.eval()
- def chat_fn(message, history):
-     prompt = "\n".join([f"Q: {h[0]}\nA: {h[1]}" for h in history])
-     prompt += f"\nQ: {message}\nA:"
-     inputs = tokenizer.encode(prompt, return_tensors="pt")
-     with torch.no_grad():
-         outputs = model.generate(
-             inputs,
-             max_new_tokens=128,
-             temperature=0.7,
-             do_sample=True,
-             pad_token_id=tokenizer.eos_token_id,
-         )
-     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     answer = answer.split("A:")[-1].strip()
-     return answer
-
- gr.ChatInterface(
-     fn=chat_fn,
-     title="Tiny-CodeNyx – 160 MB General-Knowledge Bot",
-     description="Ask anything code or general knowledge; model trained on 5k Q&A.",
-     theme="soft"
- ).queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
 
+ """
+ RealCanvas-MJ4K
+ A 16-GB-friendly Gradio Space that
+ 1. streams the prompt dataset MohamedRashad/midjourney-detailed-prompts
+ 2. generates realistic images using SDXL-Lightning
+ 3. optionally displays random images from opendiffusionai/cc12m-4mp-realistic
+ """
+
+ import gradio as gr
+ import torch, os, random, json, requests
+ from io import BytesIO
+ from PIL import Image
+ from datasets import load_dataset
+ from huggingface_hub import hf_hub_download
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+
+ # -------------------------------------------------
+ # 1. Load the prompt dataset (lazy streaming)
+ # -------------------------------------------------
+ print("🔍 Streaming prompt dataset …")
+ ds_prompts = load_dataset(
+     "MohamedRashad/midjourney-detailed-prompts",
+     split="train",
+     streaming=True
+ )
+ prompt_pool = list(ds_prompts.shuffle(seed=42).take(500_000))  # holds up to 500k prompts in RAM
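+ # shuffle() on a streaming dataset mixes rows through a fixed-size buffer
+ # rather than doing a global shuffle, and take() stops iteration early, so
+ # the full dataset is never downloaded.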
+
+ # -------------------------------------------------
+ # 2. Load SDXL-Lightning (fp16, 4-step)
+ # -------------------------------------------------
+ MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
+ print("🤖 Loading SDXL-Lightning …")
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.float16,
+     variant="fp16",
+     use_safetensors=True
+ )
+ # SDXL-Lightning is distilled for an Euler sampler with trailing timesteps
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
+     pipe.scheduler.config, timestep_spacing="trailing"
+ )
+ # lightning LoRA
+ lora_path = hf_hub_download(
+     repo_id="ByteDance/SDXL-Lightning",
+     filename="sdxl_lightning_4step_lora.safetensors"
+ )
+ pipe.load_lora_weights(lora_path)
+ pipe = pipe.to("cuda") if torch.cuda.is_available() else pipe.to("cpu")
+ pipe.enable_attention_slicing()
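+ # The 4-step Lightning LoRA is distilled to sample at guidance_scale near 0,
+ # which is why the UI below defaults to steps=4 and guidance 0.0.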
+
+ # -------------------------------------------------
+ # 3. Random CC12M-4MP image helper (optional demo)
+ # -------------------------------------------------
+ print("📸 Streaming CC12M-4MP-realistic …")
+ ds_images = load_dataset(
+     "opendiffusionai/cc12m-4mp-realistic",
+     split="train",
+     streaming=True
+ )
+ img_pool = list(ds_images.shuffle(seed=42).take(1_000))  # 1 000 samples held in RAM
+
+ def random_cc12m_image():
+     sample = random.choice(img_pool)
+     return sample["image"].resize((512, 512))
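+ # Note: 1 000 decoded 4-MP images can occupy several GB once accessed as PIL
+ # objects; lower the take() count above if the Space hits its memory limit.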
+
+ # -------------------------------------------------
+ # 4. Gradio UI
+ # -------------------------------------------------
+ def generate(prompt: str, steps: int = 4, guidance: float = 0.0):
+     if not prompt.strip():
+         prompt = random.choice(prompt_pool)["prompt"]
+     image = pipe(
+         prompt,
+         num_inference_steps=steps,
+         guidance_scale=guidance
+     ).images[0]
+     return image.resize((768, 768))
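+     # SDXL generates 1024x1024 by default; resizing to 768x768 keeps the
+     # response payload small.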
+
+ with gr.Blocks(title="RealCanvas-MJ4K") as demo:
+     gr.Markdown("# 🎨 RealCanvas-MJ4K | Midjourney-level realism under 16 GB")
+     with gr.Row():
+         prompt_in = gr.Textbox(
+             label="Prompt (leave empty for random Midjourney-style prompt)",
+             lines=2
+         )
+     with gr.Row():
+         steps = gr.Slider(1, 8, value=4, step=1, label="Inference steps (SDXL-Lightning)")
+         guidance = gr.Slider(0.0, 2.0, value=0.0, step=0.1, label="Guidance scale")
+     btn = gr.Button("Generate", variant="primary")
+     gallery = gr.Image(type="pil", label="Generated image")
+     with gr.Accordion("📸 Random CC12M-4MP sample", open=False):
+         cc_btn = gr.Button("Show random CC12M-4MP image")
+         cc_out = gr.Image(type="pil", label="Real photo from dataset")
+
+     btn.click(generate, [prompt_in, steps, guidance], gallery)
+     cc_btn.click(random_cc12m_image, outputs=cc_out)
+
+ demo.queue(max_size=8).launch()
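
The new app.py indexes dataset rows by literal column names ("prompt" for the Midjourney prompts, "image" for the CC12M photos). A quick streaming sketch like the one below, run locally with the same dataset ids, confirms those schemas before deploying:

# sanity_check.py - minimal sketch: peek at one streamed row per dataset to
# confirm the column names ("prompt", "image") that app.py assumes.
from datasets import load_dataset

for repo in ("MohamedRashad/midjourney-detailed-prompts",
             "opendiffusionai/cc12m-4mp-realistic"):
    row = next(iter(load_dataset(repo, split="train", streaming=True)))
    print(repo, "->", sorted(row.keys()))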