metadata
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: >-
To access CodeGemma on Hugging Face, you’re required to review and agree to
Google’s usage license. To do this, please ensure you’re logged-in to Hugging
Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
base_model:
- google/codegemma-1.1-2b
CodeGemma Model Card
This repository corresponds to the CodeGemma 2B checkpoint for use with Gemma PyTorch. If you're looking for the
transformers
implementation, or a more detailed model card, visit https://huggingface.co/google/codegemma-1.1-2b.
Model Page: CodeGemma
Resources and Technical Documentation:
Terms of Use: Terms
Authors: Google
Sample Usage
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
# Model size tag; it selects the config and is part of the checkpoint filename.
VARIANT = "2b"
# Device string handed to torch.device() below.
MACHINE_TYPE = "cpu"
# Local directory that holds the tokenizer model and the checkpoint file.
weights_dir = "codegemma-1.1-2b-pytorch"
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily make ``dtype`` the default torch dtype.

    The default dtype that was active on entry is captured and restored on
    exit, including when the body of the ``with`` block raises.

    Args:
        dtype: The dtype to install as the default while the context is
            active.

    Yields:
        None.
    """
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore in `finally` so a failure inside the `with` body (for
        # example while loading weights) cannot leak the temporary dtype
        # into later code.
        torch.set_default_dtype(previous_dtype)
# Pick the configuration that matches VARIANT and point it at the tokenizer
# file inside the weights directory.
if "2b" in VARIANT:
    model_config = get_config_for_2b()
else:
    model_config = get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
ckpt_path = os.path.join(weights_dir, f"codegemma-1.1-{VARIANT}.pt")

# Build the model and load its weights while the config's dtype is the
# default torch dtype, then move it to the target device in eval mode.
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    model.load_weights(ckpt_path)
    model = model.to(device).eval()
# Prompt built from the <|fim_prefix|>, <|fim_suffix|> and <|fim_middle|>
# control tokens. Presumably the model is asked to produce the code that
# belongs between the prefix ("import ") and the suffix (the __main__ guard);
# confirm the token semantics against the CodeGemma documentation.
FIM_PROMPT = """<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":
sys.exit(0)<|fim_middle|>"""
# Run generation for the prompt on `device`. The return value is not captured
# here, so assign or print it when running this as a script.
# NOTE(review): output_len presumably caps the number of generated tokens;
# verify against GemmaForCausalLM.generate.
model.generate(
FIM_PROMPT,
device=device,
output_len=100,
)