Jeff Myers II committed
Commit cb3e313 · Parent: 297702e

Removed 8-bit quantization and changed model_id to google/gemma-3n-E4B-it-litert-preview
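For context, a minimal sketch of the loading path this commit removes versus the one it introduces (assumes transformers' BitsAndBytesConfig and pipeline APIs; the model ids are taken from the diff below):

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig, pipeline

    # Before this commit: 8-bit quantized load (needs the bitsandbytes package and a GPU)
    quant_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-3-4b-it", quantization_config=quant_config
    )

    # After this commit: plain pipeline load of the new checkpoint
    pipe = pipeline("text-generation", "google/gemma-3n-E4B-it-litert-preview")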

Files changed (1):
  Gemma.py  +11 −12
Gemma.py CHANGED
@@ -1,8 +1,6 @@
-from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
-from transformers.utils import quantization_config
+from transformers import pipeline
 from huggingface_hub import login
 import spaces
-import torch
 import json
 import os
 
@@ -12,17 +10,18 @@ class GemmaLLM:
     def __init__(self):
         login(token=os.environ.get("GEMMA_TOKEN"))
 
-        quant_config = quantization_config.BitsAndBytesConfig(
-            load_in_8bit=True,
-            llm_int8_threshold=6.0,
-            llm_int8_has_fp16_weight=False,
-        )
+        # quant_config = quantization_config.BitsAndBytesConfig(
+        #     load_in_8bit=True,
+        #     llm_int8_threshold=6.0,
+        #     llm_int8_has_fp16_weight=False,
+        # )
 
-        model_id = "google/gemma-3-4b-it"
-        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        # model_id = "google/gemma-3-4b-it"
+        model_id = "google/gemma-3n-E4B-it-litert-preview"
+        # model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config)
+        # tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-        self.model = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16)
+        self.model = pipeline("text-generation", model_id)
 
     @spaces.GPU
     def generate(self, message) -> str:
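After this change, GemmaLLM no longer needs torch, AutoModelForCausalLM, or AutoTokenizer: pipeline resolves both the weights and the tokenizer from the model id, and the second positional argument of pipeline(task, model) is the model. A hedged usage sketch of the new path (whether the litert-preview checkpoint actually loads through transformers is an assumption here, since LiteRT previews are primarily packaged for on-device runtimes):

    from transformers import pipeline

    # pipeline(task, model): task first, model id second
    pipe = pipeline("text-generation", "google/gemma-3n-E4B-it-litert-preview")
    out = pipe("Hello, Gemma!", max_new_tokens=32)
    print(out[0]["generated_text"])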