Jeff Myers II committed on
Commit ef17b91 · 1 Parent(s): f24da04

Completed Prototype

Files changed (1)
  1. Gemma.py +23 -20
Gemma.py CHANGED
@@ -1,4 +1,5 @@
-from transformers import AutoTokenizer, Gemma3ForCausalLM
+# from transformers import AutoTokenizer, Gemma3ForCausalLM
+from transformers import pipeline
 from huggingface_hub import login
 import spaces
 import torch
@@ -14,30 +15,32 @@ class GemmaLLM:
 
         model_id = "google/gemma-3-1b-it"
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.model = Gemma3ForCausalLM.from_pretrained(
-            model_id,
-            device_map="cuda" if torch.cuda.is_available() else "cpu",
-            torch_dtype=torch.float16,
-        ).eval()
-
-        self.model = self.model.bfloat16()
-
+        # self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        # self.model = Gemma3ForCausalLM.from_pretrained(
+        #     model_id,
+        #     device_map="cuda" if torch.cuda.is_available() else "cpu",
+        #     torch_dtype=torch.float16,
+        # ).eval()
+
+        self.model = pipeline("text-generation", model=model_id, torch_dtype=torch.bfloat16, device="cuda")
 
     @spaces.GPU
     def generate(self, message) -> str:
-        inputs = self.tokenizer.apply_chat_template(
-            message,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        ).to(self.model.device)
-
-        input_length = inputs["input_ids"].shape[1]
-
-        with torch.inference_mode():
-            outputs = self.model.generate(**inputs, max_new_tokens=1024)[0][input_length:]
-        outputs = self.tokenizer.decode(outputs, skip_special_tokens=True)
+        # inputs = self.tokenizer.apply_chat_template(
+        #     message,
+        #     add_generation_prompt=True,
+        #     tokenize=True,
+        #     return_dict=True,
+        #     return_tensors="pt",
+        # ).to(self.model.device)
+
+        # input_length = inputs["input_ids"].shape[1]
+
+        # with torch.inference_mode():
+        #     outputs = self.model.generate(**inputs, max_new_tokens=1024)[0][input_length:]
+        # outputs = self.tokenizer.decode(outputs, skip_special_tokens=True)
+
+        outputs = self.model(message, max_new_tokens=1024)[0]["generated_text"]
 
         return outputs
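
For context, a minimal usage sketch of the new pipeline-based code path (not part of the commit; the no-argument constructor, the `from Gemma import GemmaLLM` import path, and the chat-style message list are assumptions based on the file name and the commented-out apply_chat_template code):

# Hypothetical usage sketch for the pipeline-based GemmaLLM (assumed
# constructor and import path; not part of this commit).
from Gemma import GemmaLLM

llm = GemmaLLM()

# A transformers text-generation pipeline accepts either a plain prompt
# string or a chat-style list of {"role": ..., "content": ...} messages.
message = [
    {"role": "user", "content": "Give me a one-sentence status update."},
]

outputs = llm.generate(message)
print(outputs)

One note on the return shape: with chat-format input, recent transformers versions return the whole conversation under "generated_text" (a list of messages), so a caller may need to take the last message's "content"; with a plain string prompt, "generated_text" is a string that includes the prompt unless return_full_text=False is passed.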