Jeff Myers II committed
Commit 297702e · 1 Parent(s): 5278642

Using quantization_config instead of load_in_8bit

Files changed (1):
  1. Gemma.py (+8, -1)
Gemma.py CHANGED
@@ -1,4 +1,5 @@
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from transformers.utils import quantization_config
 from huggingface_hub import login
 import spaces
 import torch
@@ -11,8 +12,14 @@ class GemmaLLM:
     def __init__(self):
         login(token=os.environ.get("GEMMA_TOKEN"))
 
+        quant_config = quantization_config.BitsAndBytesConfig(
+            load_in_8bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+        )
+
         model_id = "google/gemma-3-4b-it"
-        model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
 
         self.model = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
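For context: recent transformers releases deprecate passing load_in_8bit=True directly to from_pretrained, and the supported path is a BitsAndBytesConfig object passed via quantization_config, which is what this change switches to. Below is a minimal standalone sketch of the same loading pattern, using the top-level BitsAndBytesConfig export (equivalent to the transformers.utils.quantization_config module path used in the commit) and assuming a CUDA GPU with the bitsandbytes package installed; the model ID and config values are taken from the diff above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# 8-bit quantization settings mirroring the commit's BitsAndBytesConfig arguments.
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,             # quantize linear layers to int8 via bitsandbytes
    llm_int8_threshold=6.0,        # outlier threshold for mixed int8/fp16 matmuls
    llm_int8_has_fp16_weight=False,
)

model_id = "google/gemma-3-4b-it"

# quantization_config replaces the deprecated load_in_8bit=True kwarg on from_pretrained.
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Same pipeline call as in Gemma.py.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16)

Note that llm_int8_threshold=6.0 and llm_int8_has_fp16_weight=False match the library defaults, so spelling them out here mainly documents the intended int8 behavior rather than changing it.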