# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example usage:

accelerate launch \
    --config_file=deepspeed_zero2.yaml \
    sft_video_llm.py \
    --dataset_name=mfarre/simplevideoshorts \
    --video_cache_dir="/optional/path/to/cache/" \
    --model_name_or_path=Qwen/Qwen2-VL-7B-Instruct \
    --per_device_train_batch_size=1 \
    --output_dir=video-llm-output \
    --bf16=True \
    --tf32=True \
    --gradient_accumulation_steps=4 \
    --num_train_epochs=4 \
    --optim="adamw_torch_fused" \
    --logging_steps=1 \
    --log_level="debug" \
    --log_level_replica="debug" \
    --save_strategy="steps" \
    --save_steps=300 \
    --learning_rate=8e-5 \
    --max_grad_norm=0.3 \
    --warmup_ratio=0.1 \
    --lr_scheduler_type="cosine" \
    --report_to="wandb" \
    --push_to_hub=False \
    --torch_dtype=bfloat16 \
    --gradient_checkpointing=True
"""

import json
import os
import random
from dataclasses import dataclass, field
from typing import Any

import requests
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor

from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map


def download_video(url: str, cache_dir: str) -> str:
    """Download video if not already present locally."""
    os.makedirs(cache_dir, exist_ok=True)  # Create cache dir if it doesn't exist
    filename = url.split("/")[-1]
    local_path = os.path.join(cache_dir, filename)

    if os.path.exists(local_path):
        return local_path

    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        return local_path
    except requests.RequestException as e:
        raise Exception(f"Failed to download video: {e}") from e


def prepare_dataset(example: dict[str, Any], cache_dir: str) -> dict[str, list[dict[str, Any]]]:
    """Prepare dataset example for training."""
    video_url = example["video_url"]
    timecoded_cc = example["timecoded_cc"]
    qa_pairs = json.loads(example["qa"])

    system_message = "You are an expert in movie narrative analysis."

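    # Build a user prompt that pairs the downloaded video with its timecoded subtitles
    # and one randomly sampled question/answer pair from the example.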
base_prompt = f"""Analyze the video and consider the following timecoded subtitles: {timecoded_cc} Based on this information, please answer the following questions:""" selected_qa = random.sample(qa_pairs, 1)[0] messages = [ {"role": "system", "content": [{"type": "text", "text": system_message}]}, { "role": "user", "content": [ {"type": "video", "video": download_video(video_url, cache_dir), "max_pixels": 360 * 420, "fps": 1.0}, {"type": "text", "text": f"{base_prompt}\n\nQuestion: {selected_qa['question']}"}, ], }, {"role": "assistant", "content": [{"type": "text", "text": selected_qa["answer"]}]}, ] return {"messages": messages} def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: """Collate batch of examples for training.""" texts = [] video_inputs = [] for i, example in enumerate(examples): try: video_path = next( content["video"] for message in example["messages"] for content in message["content"] if content.get("type") == "video" ) print(f"Processing video: {os.path.basename(video_path)}") texts.append(processor.apply_chat_template(example["messages"], tokenize=False)) video_input = process_vision_info(example["messages"])[1][0] video_inputs.append(video_input) except Exception as e: raise ValueError(f"Failed to process example {i}: {e}") from e inputs = processor(text=texts, videos=video_inputs, return_tensors="pt", padding=True) labels = inputs["input_ids"].clone() labels[labels == processor.tokenizer.pad_token_id] = -100 # Handle visual tokens based on processor type visual_tokens = ( [151652, 151653, 151656] if isinstance(processor, Qwen2VLProcessor) else [processor.tokenizer.convert_tokens_to_ids(processor.image_token)] ) for visual_token_id in visual_tokens: labels[labels == visual_token_id] = -100 inputs["labels"] = labels return inputs @dataclass class CustomScriptArguments(ScriptArguments): r""" Arguments for the script. Args: video_cache_dir (`str`, *optional*, defaults to `"/tmp/videos/"`): Video cache directory. 
""" video_cache_dir: str = field(default="/tmp/videos/", metadata={"help": "Video cache directory."}) if __name__ == "__main__": # Parse arguments parser = TrlParser((CustomScriptArguments, SFTConfig, ModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() # Configure training args training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} # Load dataset dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train") # Setup model torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) ) # Quantization configuration for 4-bit training bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) # Model initialization model_kwargs = dict( revision=model_args.model_revision, trust_remote_code=model_args.trust_remote_code, torch_dtype=torch_dtype, device_map=get_kbit_device_map(), quantization_config=bnb_config, ) model = AutoModelForVision2Seq.from_pretrained(model_args.model_name_or_path, **model_kwargs) peft_config = LoraConfig( task_type="CAUSAL_LM", r=16, lora_alpha=16, lora_dropout=0.1, bias="none", target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ) # Configure model modules for gradients if training_args.gradient_checkpointing: model.gradient_checkpointing_enable() model.config.use_reentrant = False model.enable_input_require_grads() processor = AutoProcessor.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code ) # Prepare dataset prepared_dataset = [prepare_dataset(example, script_args.video_cache_dir) for example in dataset] # Initialize wandb if specified if training_args.report_to == "wandb": wandb.init(project="video-llm-training") # Initialize trainer trainer = SFTTrainer( model=model, args=training_args, train_dataset=prepared_dataset, data_collator=collate_fn, peft_config=peft_config, tokenizer=processor.tokenizer, ) # Train model trainer.train() # Save final model trainer.save_model(training_args.output_dir) if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name) if trainer.accelerator.is_main_process: processor.push_to_hub(training_args.hub_model_id) # Cleanup del model del trainer torch.cuda.empty_cache() wandb.finish()