awacke1 committed
Commit f873e60 · verified · 1 Parent(s): 1084118

Update app.py

Files changed (1)
  1. app.py +88 -51
app.py CHANGED
@@ -4,16 +4,15 @@ from transformers import pipeline
  import os

  # --- App Configuration ---
- TITLE = "✍️ AI Story Weaver"
+ TITLE = "✍️ AI Story Outliner"
  DESCRIPTION = """
- Enter a prompt, a topic, or the beginning of a story, and get three different continuations from powerful open-source AI models.
- This app uses:
- - **Mistral-7B-Instruct-v0.2**
- - **Google's Gemma-7B-IT**
- - **Meta's Llama-3-8B-Instruct**
-
- **⚠️ Hardware Warning:** These are very large models. Loading them requires a powerful GPU with significant VRAM (ideally > 24GB).
- The initial loading process may take several minutes. You will also need to install the `accelerate` library: `pip install accelerate`
+ Enter a prompt and get 10 unique story outlines from a CPU-friendly AI model.
+ The app uses **TinyLlama-1.1B** to generate creative outlines formatted in Markdown.
+
+ **How it works:**
+ 1. Enter your story idea.
+ 2. The AI will generate 10 different story outlines.
+ 3. Each outline has a dramatic beginning and is concise, like a song.
  """

  # --- Example Prompts for Storytelling ---
@@ -26,52 +25,84 @@ examples = [
  ]

  # --- Model Initialization ---
- # This section loads the models. It requires significant hardware resources.
- # `device_map="auto"` and `torch_dtype="auto"` help manage resources by using available GPUs and half-precision.
+ # This section loads a smaller, CPU-friendly model.
+ # It will automatically use the HF_TOKEN secret when deployed on Hugging Face Spaces.
  try:
- print("Initializing models... This may take several minutes.")
+ print("Initializing model... This may take a moment.")

- # NOTE: For Llama-3, you may need to log in to Hugging Face and accept the license agreement.
- # from huggingface_hub import login
- # login("YOUR_HF_TOKEN")
-
- generator1 = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", torch_dtype="auto", device_map="auto")
- print("✅ Mistral-7B loaded.")
+ # Load the token from environment variables if it exists (for HF Spaces secrets)
+ hf_token = os.environ.get("HF_TOKEN", None)

- generator2 = pipeline("text-generation", model="google/gemma-7b-it", torch_dtype="auto", device_map="auto")
- print("✅ Gemma-7B loaded.")
-
- generator3 = pipeline("text-generation", model="meta-llama/Llama-3-8B-Instruct", torch_dtype="auto", device_map="auto")
- print("✅ Llama-3-8B loaded.")
-
- print("All models loaded successfully! 🎉")
+ # Using a smaller model that is more suitable for running without a high-end GPU.
+ generator = pipeline(
+ "text-generation",
+ model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+ torch_dtype=torch.bfloat16, # More efficient dtype
+ device_map="auto", # Will use GPU if available, otherwise CPU
+ token=hf_token
+ )
+ print("✅ TinyLlama model loaded successfully!")

  except Exception as e:
  print(f"--- 🚨 Error loading models ---")
  print(f"Error: {e}")
- print("Please ensure you have 'torch' and 'accelerate' installed, have sufficient VRAM, and are logged into Hugging Face if required.")
  # Create a dummy function if models fail, so the app can still launch with an error message.
  def failed_generator(prompt, **kwargs):
- return [{'generated_text': "A model failed to load. Please check the console for errors. You may need more VRAM or need to accept model license terms on Hugging Face."}]
- generator1 = generator2 = generator3 = failed_generator
+ error_message = f"Model failed to load. Please check the console for errors. Error: {e}"
+ return [{'generated_text': error_message}]
+ generator = failed_generator


  # --- App Logic ---
- def generate_stories(prompt: str) -> tuple[str, str, str]:
- """Generates text from the three loaded models based on the user's prompt."""
+ def generate_stories(prompt: str) -> list[str]:
+ """
+ Generates 10 story outlines from the loaded model based on the user's prompt.
+ """
  if not prompt:
- return "Please enter a prompt to start.", "", ""
-
- # We use 'max_new_tokens' to control the length of the generated story.
- # Increased to 200 for more substantial story continuations.
- params = {"max_new_tokens": 200, "do_sample": True, "temperature": 0.7, "top_p": 0.95}
-
- # Generate from all three models
- out1 = generator1(prompt, **params)[0]['generated_text']
- out2 = generator2(prompt, **params)[0]['generated_text']
- out3 = generator3(prompt, **params)[0]['generated_text']
+ # Return a list of 10 empty strings to clear the outputs
+ return [""] * 10
+
+ # A detailed system prompt to guide the model's output format and structure.
+ system_prompt = f"""
+ <|system|>
+ You are an expert storyteller. Your task is to take a user's prompt and write
+ a short story as a Markdown outline. The story must have a dramatic arc and be
+ the length of a song. Use emojis to highlight the story sections.
+
+ **Your Story Outline Structure:**
+ - 🎬 **The Hook:** A dramatic opening.
+ - 🎼 **The Ballad:** The main story, told concisely.
+ - 🔚 **The Finale:** A clear and satisfying ending.</s>
+ <|user|>
+ {prompt}</s>
+ <|assistant|>
+ """
+
+ # Parameters for the pipeline to generate 10 diverse results.
+ params = {
+ "max_new_tokens": 250,
+ "num_return_sequences": 10,
+ "do_sample": True,
+ "temperature": 0.8,
+ "top_k": 50,
+ "top_p": 0.95,
+ }

- return out1, out2, out3
+ # Generate 10 different story variations
+ outputs = generator(system_prompt, **params)
+
+ # Extract the generated text and clean it up.
+ stories = []
+ for out in outputs:
+ # Remove the system prompt from the beginning of the output
+ cleaned_text = out['generated_text'].replace(system_prompt, "").strip()
+ stories.append(cleaned_text)
+
+ # Ensure we return exactly 10 stories, padding with an error message if necessary.
+ while len(stories) < 10:
+ stories.append("Failed to generate a story for this slot.")
+
+ return stories

  # --- Gradio Interface ---
  with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
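Aside (not part of the diff): the new generate_stories hard-codes TinyLlama's <|system|>/<|user|>/<|assistant|> chat tags in an f-string. A minimal sketch of producing the same layout with the tokenizer's own chat template, assuming a recent transformers release:

# Sketch only, not part of this commit: let the tokenizer build the chat-format prompt.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [
    {"role": "system", "content": "You are an expert storyteller..."},  # abbreviated system prompt
    {"role": "user", "content": "A lighthouse keeper finds a message in a bottle."},
]
# Produces the same <|system|> / <|user|> / <|assistant|> layout as the hand-written f-string.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)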
@@ -85,16 +116,22 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !i
  label="Your Story Prompt 👇",
  placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'"
  )
- generate_button = gr.Button("Weave a Story ✨", variant="primary")
+ generate_button = gr.Button("Generate 10 Outlines ✨", variant="primary")
+
+ gr.Markdown("---")
+ gr.Markdown("## 📖 Your 10 Story Outlines")

- with gr.Column(scale=2):
- with gr.Tabs():
- with gr.TabItem("Mistral-7B"):
- gen1_output = gr.TextArea(label="Mistral's Tale", interactive=False, lines=12)
- with gr.TabItem("Gemma-7B"):
- gen2_output = gr.TextArea(label="Gemma's Chronicle", interactive=False, lines=12)
- with gr.TabItem("Llama-3-8B"):
- gen3_output = gr.TextArea(label="Llama's Legend", interactive=False, lines=12)
+ # Create 10 markdown components to display the stories in two columns
+ story_outputs = []
+ with gr.Row():
+ with gr.Column():
+ for i in range(5):
+ md = gr.Markdown(label=f"Story Outline {i + 1}")
+ story_outputs.append(md)
+ with gr.Column():
+ for i in range(5, 10):
+ md = gr.Markdown(label=f"Story Outline {i + 1}")
+ story_outputs.append(md)

  gr.Examples(
  examples=examples,
@@ -105,7 +142,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !i
  generate_button.click(
  fn=generate_stories,
  inputs=input_area,
- outputs=[gen1_output, gen2_output, gen3_output],
+ outputs=story_outputs,
  api_name="generate"
  )

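Aside (not part of the diff): because generate_button.click(...) sets api_name="generate", the deployed Space also exposes a programmatic endpoint. A minimal sketch of calling it with gradio_client; the Space id below is a placeholder, not taken from this commit:

# Sketch only: query the Space's /generate endpoint once the app is deployed.
from gradio_client import Client

client = Client("awacke1/REPLACE-WITH-SPACE-ID")  # placeholder Space id
outlines = client.predict(
    "The last dragon on Earth lived not in a cave, but in a library...",
    api_name="/generate",
)
print(len(outlines))  # the handler returns 10 Markdown outlines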