awacke1 committed
Commit 95de190 · verified · 1 Parent(s): f873e60

Update app.py

Files changed (1):
  1. app.py +40 -39
app.py CHANGED
@@ -7,7 +7,7 @@ import os
 TITLE = "✍️ AI Story Outliner"
 DESCRIPTION = """
 Enter a prompt and get 10 unique story outlines from a CPU-friendly AI model.
-The app uses **TinyLlama-1.1B** to generate creative outlines formatted in Markdown.
+The app uses **DistilGPT-2** to generate creative outlines formatted in Markdown.
 
 **How it works:**
 1. Enter your story idea.
@@ -25,32 +25,26 @@ examples = [
 ]
 
 # --- Model Initialization ---
-# This section loads a smaller, CPU-friendly model.
-# It will automatically use the HF_TOKEN secret when deployed on Hugging Face Spaces.
+# This section loads a smaller, CPU-friendly model that does not require a token.
+generator = None
+model_error = None
+
 try:
     print("Initializing model... This may take a moment.")
 
-    # Load the token from environment variables if it exists (for HF Spaces secrets)
-    hf_token = os.environ.get("HF_TOKEN", None)
-
-    # Using a smaller model that is more suitable for running without a high-end GPU.
+    # Using a smaller, fully open model that does not require authentication.
     generator = pipeline(
         "text-generation",
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        torch_dtype=torch.bfloat16, # More efficient dtype
-        device_map="auto", # Will use GPU if available, otherwise CPU
-        token=hf_token
+        model="distilgpt2",
+        torch_dtype=torch.float32, # Use float32 for wider CPU compatibility
+        device_map="auto" # Will use GPU if available, otherwise CPU
     )
-    print("✅ TinyLlama model loaded successfully!")
+    print("✅ DistilGPT-2 model loaded successfully!")
 
 except Exception as e:
-    print(f"--- 🚨 Error loading models ---")
-    print(f"Error: {e}")
-    # Create a dummy function if models fail, so the app can still launch with an error message.
-    def failed_generator(prompt, **kwargs):
-        error_message = f"Model failed to load. Please check the console for errors. Error: {e}"
-        return [{'generated_text': error_message}]
-    generator = failed_generator
+    model_error = e
+    print(f"--- 🚨 Error loading model ---")
+    print(f"Error: {model_error}")
 
 
 # --- App Logic ---
@@ -58,47 +52,54 @@ def generate_stories(prompt: str) -> list[str]:
     """
     Generates 10 story outlines from the loaded model based on the user's prompt.
     """
+    # If the model failed to load, display the error in all output boxes.
+    if model_error:
+        error_message = f"**Model failed to load.**\n\nPlease check the console logs for details.\n\n**Error:**\n`{str(model_error)}`"
+        return [error_message] * 10
+
     if not prompt:
         # Return a list of 10 empty strings to clear the outputs
         return [""] * 10
 
     # A detailed system prompt to guide the model's output format and structure.
-    system_prompt = f"""
-    <|system|>
-    You are an expert storyteller. Your task is to take a user's prompt and write
-    a short story as a Markdown outline. The story must have a dramatic arc and be
-    the length of a song. Use emojis to highlight the story sections.
-
-    **Your Story Outline Structure:**
-    - 🎬 **The Hook:** A dramatic opening.
-    - 🎼 **The Ballad:** The main story, told concisely.
-    - 🔚 **The Finale:** A clear and satisfying ending.</s>
-    <|user|>
-    {prompt}</s>
-    <|assistant|>
+    # Note: Simpler models like DistilGPT-2 may not follow complex instructions perfectly.
+    story_prompt = f"""
+    Story Idea: "{prompt}"
+
+    Create a short story outline based on this idea.
+
+    ### The Hook
+    A dramatic opening.
+
+    ### The Ballad
+    The main story, told concisely.
+
+    ### The Finale
+    A clear and satisfying ending.
+    ---
     """
 
     # Parameters for the pipeline to generate 10 diverse results.
     params = {
-        "max_new_tokens": 250,
+        "max_new_tokens": 200,
         "num_return_sequences": 10,
         "do_sample": True,
-        "temperature": 0.8,
+        "temperature": 0.9,
         "top_k": 50,
        "top_p": 0.95,
    }
 
     # Generate 10 different story variations
-    outputs = generator(system_prompt, **params)
+    outputs = generator(story_prompt, **params)
 
     # Extract the generated text and clean it up.
     stories = []
     for out in outputs:
-        # Remove the system prompt from the beginning of the output
-        cleaned_text = out['generated_text'].replace(system_prompt, "").strip()
-        stories.append(cleaned_text)
+        # The model will generate the prompt plus the continuation, so we format it.
+        full_text = out['generated_text']
+        stories.append(full_text)
 
-    # Ensure we return exactly 10 stories, padding with an error message if necessary.
-
+    # Ensure we return exactly 10 stories, padding if necessary.
     while len(stories) < 10:
         stories.append("Failed to generate a story for this slot.")
+
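
A quick way to exercise the new loading path outside of Gradio is a standalone smoke test. The sketch below mirrors the model id, dtype, and sampling settings from the diff; it assumes `transformers`, `torch`, and `accelerate` are installed (`device_map="auto"` needs `accelerate`), and the story idea string is only an example.

```python
# Smoke test for the new loading path (a sketch, not part of the commit).
# "distilgpt2" and the sampling settings mirror the diff above.
import torch
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="distilgpt2",
    torch_dtype=torch.float32,  # CPU-friendly dtype, as in the commit
    device_map="auto",          # GPU if available, otherwise CPU
)

outputs = generator(
    'Story Idea: "a lighthouse keeper who befriends a storm"\n',  # example prompt
    max_new_tokens=60,
    num_return_sequences=3,  # multiple sequences require do_sample=True
    do_sample=True,
    temperature=0.9,
    top_k=50,
    top_p=0.95,
)
for i, out in enumerate(outputs, 1):
    print(f"--- outline {i} ---")
    print(out["generated_text"][:200])
```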
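
One caveat in the new extraction loop: GPT-2-family pipelines include the prompt itself in `generated_text`, so each output box will start with the full `story_prompt` even though the comment says the output is formatted. A possible follow-up, not part of this commit, is to request only the continuation with the text-generation pipeline's `return_full_text` flag:

```python
# Hypothetical follow-up: return only the model's continuation, not the echoed
# prompt. Assumes the `generator`, `story_prompt`, and `params` names from app.py.
outputs = generator(story_prompt, return_full_text=False, **params)
stories = [out["generated_text"].strip() for out in outputs]
```

This leaves `params` unchanged and only alters what the pipeline returns.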