# handler.py
# Hugging Face Inference Endpoint custom handler for Mongolian GPT-2 summarization
#
# Input JSON:
# {
#   "inputs": "ARTICLE TEXT ...",
#   "parameters": {
#     "max_new_tokens": 160,
#     "num_beams": 4,
#     "do_sample": false,
#     "no_repeat_ngram_size": 3,
#     "length_penalty": 1.0,
#     "temperature": 1.0,
#     "top_p": 1.0,
#     "top_k": 50,
#     "return_full_text": false
#   }
# }
#
# Output JSON:
# { "summary_text": "...", "used_new_tokens": 152, "requested_new_tokens": 160 }
#
# See the __main__ block at the bottom of this file for a local usage example.

from typing import Any, Dict, List, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Mongolian instruction + prompt template used during training
INSTRUCTION = "Дараах бичвэрийг хураангуйлж бич."
PROMPT_TEMPLATE = (
    "### Даалгавар:\n"
    f"{INSTRUCTION}\n\n"
    "### Бичвэр:\n{article}\n\n"
    "### Хураангуй:\n"
)


def _select_dtype() -> torch.dtype:
    if torch.cuda.is_available():
        # Prefer bf16 if supported; otherwise use fp16
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


class EndpointHandler:
    """
    Custom handler for HF Inference Endpoints:
      - __init__(path): loads model assets from `path`
      - __call__(data): performs generation given {"inputs": ..., "parameters": {...}}
    """

    def __init__(self, path: str = ""):
        # Device & dtype
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = _select_dtype()

        # Load tokenizer/model from the repository directory
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        # Decoder-only models require left padding for correct batched generation
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=self.dtype,
        ).to(self.device)
        # Safer attention path on many endpoint stacks
        self.model.config.attn_implementation = "eager"
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.eos_token_id = self.tokenizer.eos_token_id
        self.model.eval()

        # Read max context from config (GPT-2 default is 1024)
        self.max_context = getattr(self.model.config, "max_position_embeddings", 1024)

    def _build_prompt(self, article: str) -> str:
        return PROMPT_TEMPLATE.format(article=article.strip())

    def _prepare_inputs(self, articles: List[str], requested_new: int):
        """
        Tokenize prompts so that prompt_len + max_new_tokens <= max_context.
        We first clamp requested_new, then truncate each prompt to
        max_context - requested_new tokens.
        """
        # Basic safety clamps
        requested_new = int(max(1, min(requested_new, 512)))
        max_len_for_prompt = max(1, self.max_context - requested_new)

        prompts = [self._build_prompt(a) for a in articles]
        enc = self.tokenizer(
            prompts,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len_for_prompt,
            return_tensors="pt",
            padding=True,  # uses left padding because tokenizer.padding_side="left"
        )
        enc = {k: v.to(self.device) for k, v in enc.items()}

        # Compute per-example available space and adjust new tokens if needed
        input_lens = enc["attention_mask"].sum(dim=1).tolist()
        per_example_new = []
        for L in input_lens:
            available = max(0, self.max_context - int(L))
            per_example_new.append(max(1, min(requested_new, available)))

        return enc, per_example_new, prompts

    @torch.no_grad()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Accept either {"inputs": "..."} or {"inputs": ["...", "..."]}
        raw_inputs: Union[str, List[str], None] = data.get("inputs")
        params: Dict[str, Any] = data.get("parameters", {}) or {}

        # Default generation hyperparameters (aligned with training)
        req_new = int(params.get("max_new_tokens", 160))
        num_beams = int(params.get("num_beams", 4))
        do_sample = bool(params.get("do_sample", False))
        no_repeat = int(params.get("no_repeat_ngram_size", 3))
        length_penalty = float(params.get("length_penalty", 1.0))
        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        top_k = int(params.get("top_k", 50))
        return_full_text = bool(params.get("return_full_text", False))

        # Normalize inputs to a list of strings
        if isinstance(raw_inputs, str):
            articles = [raw_inputs]
        elif isinstance(raw_inputs, list):
            if not all(isinstance(x, str) for x in raw_inputs):
                raise ValueError("All elements of 'inputs' must be strings.")
            articles = raw_inputs
        else:
            # Accept {"article": "..."} as a courtesy (also reached when "inputs" is missing)
            maybe_article = data.get("article")
            if isinstance(maybe_article, str):
                articles = [maybe_article]
            else:
                raise ValueError("Expect 'inputs' as a string or list of strings.")

        # Tokenize prompts and cap new tokens per example
        enc, per_example_new, prompts = self._prepare_inputs(articles, req_new)

        # Generate (batched)
        gen_out = self.model.generate(
            **enc,
            max_new_tokens=max(per_example_new),  # upper bound; actual stopping still respects EOS
            num_beams=num_beams,
            do_sample=do_sample,
            no_repeat_ngram_size=no_repeat,
            length_penalty=length_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            early_stopping=True,
        )

        # Decode and postprocess per item (cut after the prompt if needed)
        decoded = self.tokenizer.batch_decode(gen_out, skip_special_tokens=True)

        results = []
        for i, text in enumerate(decoded):
            if return_full_text:
                full = text.strip()
                # Also extract the summary part for convenience
                split_key = "### Хураангуй:\n"
                summary = full.split(split_key, 1)[-1].strip() if split_key in full else full
            else:
                # Remove the prompt prefix, return only the generated summary
                prefix = prompts[i]
                if text.startswith(prefix):
                    summary = text[len(prefix):].strip()
                else:
                    # Fallback: split on the marker (the prompt may have been truncated)
                    split_key = "### Хураангуй:\n"
                    summary = text.split(split_key, 1)[-1].strip() if split_key in text else text.strip()
                full = None

            results.append({
                "summary_text": summary,
                "used_new_tokens": per_example_new[i],
                "requested_new_tokens": req_new,
                **({"full_text": full} if return_full_text else {}),
            })

        # If the input was a single string, return a single object
        if isinstance(raw_inputs, str):
            return results[0]
        return {"results": results}
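

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch showing how the payload documented at the
# top of this file maps onto EndpointHandler. The Inference Endpoint runtime
# only imports EndpointHandler, so this block never runs in deployment. The
# path "." and the "ARTICLE TEXT ..." placeholder are assumptions for local
# use; point `path` at a directory containing the fine-tuned checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumption: model files sit next to handler.py
    payload = {
        "inputs": "ARTICLE TEXT ...",  # replace with a real Mongolian article
        "parameters": {"max_new_tokens": 160, "num_beams": 4, "no_repeat_ngram_size": 3},
    }
    out = handler(payload)  # single string in -> single object out
    print(out["summary_text"])
    print(out["used_new_tokens"], out["requested_new_tokens"])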