---
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
  - pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: >-
  To access CodeGemma on Hugging Face, you’re required to review and agree to
  Google’s usage license. To do this, please ensure you’re logged in to Hugging
  Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
base_model:
  - google/codegemma-1.1-2b
---

# CodeGemma Model Card

This repository corresponds to the CodeGemma 1.1 2B checkpoint for use with Gemma PyTorch. If you're looking for the transformers implementation or a more detailed model card, visit https://huggingface.co/google/codegemma-1.1-2b.

**Model Page**: CodeGemma

**Resources and Technical Documentation**:

**Terms of Use**: Terms

**Authors**: Google

## Sample Usage
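
The sample below assumes the checkpoint (`codegemma-1.1-2b.pt`) and `tokenizer.model` from this repository have already been downloaded to a local directory. A minimal sketch of one way to fetch them with `huggingface_hub` (the repository id and `local_dir` shown here are assumptions; adjust them to this repo and to the `weights_dir` used below):

```python
# Illustrative download step; requires `pip install huggingface_hub` and an
# access token for a Hugging Face account that has accepted the Gemma license.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="google/codegemma-1.1-2b-pytorch",  # assumed id; use this repo's id
    local_dir="codegemma-1.1-2b-pytorch",       # matches weights_dir below
)
```

With the files in place, the checkpoint can be loaded and queried as follows: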

```python
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

# Model variant, target device, and local directory holding the downloaded weights.
VARIANT = "2b"
MACHINE_TYPE = "cpu"
weights_dir = "codegemma-1.1-2b-pytorch"

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)

# Build the 2B config and point it at the bundled SentencePiece tokenizer.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

# Instantiate the model in the checkpoint's dtype, load the weights, and move it to the device.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  ckpt_path = os.path.join(weights_dir, f"codegemma-1.1-{VARIANT}.pt")
  model.load_weights(ckpt_path)
  model = model.to(device).eval()

# Fill-in-the-middle prompt: the model completes the code between the prefix and suffix.
FIM_PROMPT = """<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":
    sys.exit(0)<|fim_middle|>"""

# Generate up to 100 tokens for the FIM prompt.
model.generate(
    FIM_PROMPT,
    device=device,
    output_len=100,
)
```
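
The prompt above uses the fill-in-the-middle format: the model is given the code before (`<|fim_prefix|>`) and after (`<|fim_suffix|>`) a gap, and generates the missing middle following `<|fim_middle|>`. `generate` also accepts an ordinary completion prompt; a short sketch, assuming it returns the decoded completion string for a single prompt as in the call above:

```python
# Plain (non-FIM) code completion with the same sampling defaults.
completion = model.generate(
    "def fibonacci(n):",
    device=device,
    output_len=100,
)
print(completion)  # assumed to be the decoded text for a single string prompt
```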