Jeff Myers II committed
Commit 32b2ab9 · 1 Parent(s): 7ba0657

Attempting to enable 8-bit quantization

Files changed (2)
  1. Gemma.py +4 -4
  2. requirements.txt +1 -0
Gemma.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import pipeline
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 from huggingface_hub import login
 import spaces
 import torch
@@ -8,13 +8,13 @@ import os
 __export__ = ["GemmaLLM"]
 
 class GemmaLLM:
-
     def __init__(self):
         login(token=os.environ.get("GEMMA_TOKEN"))
 
         model_id = "google/gemma-3-4b-it"
-
-        self.model = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16, device="cuda")
+        model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, pad_token_id=0)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, device="auto")
 
     @spaces.GPU
     def generate(self, message) -> str:
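
Note: the committed code passes load_in_8bit directly to from_pretrained and device="auto" to pipeline. In transformers 4.50 the bare load_in_8bit kwarg is deprecated in favour of an explicit BitsAndBytesConfig passed as quantization_config, and pipeline expects device_map rather than device for "auto" placement; an 8-bit model placed by accelerate should not be given a device argument again. A minimal sketch of the equivalent setup under those assumptions (untested on this Space, CUDA assumed available for bitsandbytes):

import os
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

login(token=os.environ.get("GEMMA_TOKEN"))
model_id = "google/gemma-3-4b-it"

# Explicit 8-bit config instead of the deprecated load_in_8bit kwarg.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
    device_map="auto",  # let accelerate place the quantized weights
    pad_token_id=0,     # kept from the commit; forwarded to the model config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# No device or torch_dtype here: the model is already quantized and placed.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)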
requirements.txt CHANGED
@@ -5,4 +5,5 @@ newspaper3k==0.2.8
 transformers==4.50.0
 lxml_html_clean==0.4.1
 accelerate==1.5.2
+bitsandbytes==0.45.3
 spaces