# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
python examples/scripts/ddpo.py \
    --num_epochs=200 \
    --train_gradient_accumulation_steps=1 \
    --sample_num_steps=50 \
    --sample_batch_size=6 \
    --train_batch_size=3 \
    --sample_num_batches_per_epoch=4 \
    --per_prompt_stat_tracking=True \
    --per_prompt_stat_tracking_buffer_size=32 \
    --tracker_project_name="stable_diffusion_training" \
    --log_with="wandb"
"""

import os
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPModel, CLIPProcessor, HfArgumentParser, is_torch_npu_available, is_torch_xpu_available

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        pretrained_model (`str`, *optional*, defaults to `"runwayml/stable-diffusion-v1-5"`):
            Pretrained model to use.
        pretrained_revision (`str`, *optional*, defaults to `"main"`):
            Pretrained model revision to use.
        hf_hub_model_id (`str`, *optional*, defaults to `"ddpo-finetuned-stable-diffusion"`):
            HuggingFace repo to save model weights to.
        hf_hub_aesthetic_model_id (`str`, *optional*, defaults to `"trl-lib/ddpo-aesthetic-predictor"`):
            Hugging Face model ID for aesthetic scorer model weights.
        hf_hub_aesthetic_model_filename (`str`, *optional*, defaults to `"aesthetic-model.pth"`):
            Hugging Face model filename for aesthetic scorer model weights.
        use_lora (`bool`, *optional*, defaults to `True`):
            Whether to use LoRA.
    """

    pretrained_model: str = field(
        default="runwayml/stable-diffusion-v1-5", metadata={"help": "Pretrained model to use."}
    )
    pretrained_revision: str = field(default="main", metadata={"help": "Pretrained model revision to use."})
    hf_hub_model_id: str = field(
        default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to."}
    )
    hf_hub_aesthetic_model_id: str = field(
        default="trl-lib/ddpo-aesthetic-predictor",
        metadata={"help": "Hugging Face model ID for aesthetic scorer model weights."},
    )
    hf_hub_aesthetic_model_filename: str = field(
        default="aesthetic-model.pth",
        metadata={"help": "Hugging Face model filename for aesthetic scorer model weights."},
    )
    use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})


class MLP(nn.Module):
    """Regression head mapping a 768-dim image embedding to a single aesthetic score."""

    def __init__(self):
        super().__init__()
        # Layer sizes must match the pretrained checkpoint loaded by `AestheticScorer`.
        self.layers = nn.Sequential(
            nn.Linear(768, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

    @torch.no_grad()
    def forward(self, embed):
        # Inference only: the scorer is never trained here, so no autograd graph is built.
        return self.layers(embed)


class AestheticScorer(torch.nn.Module):
    """
    This model attempts to predict the aesthetic score of an image. The aesthetic score
    is a numerical approximation of how much a specific image is liked by humans on average.
    This is from https://github.com/christophschuhmann/improved-aesthetic-predictor
    """

    def __init__(self, *, dtype, model_id, model_filename):
        """
        Args:
            dtype: dtype the processor outputs are cast to before entering CLIP.
            model_id: Hub repo id (or local directory) holding the MLP weights.
            model_filename: Filename of the MLP state dict inside `model_id`.
        """
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.mlp = MLP()
        local_path = os.path.join(model_id, model_filename)
        if os.path.isfile(local_path):
            # Prefer an existing local checkpoint: paths such as `./ckpt` or absolute paths are not
            # valid Hub repo ids, so `hf_hub_download` would reject them before any fallback runs.
            cached_path = local_path
        else:
            try:
                cached_path = hf_hub_download(model_id, model_filename)
            except EntryNotFoundError:
                cached_path = local_path
        state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True)
        self.mlp.load_state_dict(state_dict)
        self.dtype = dtype
        self.eval()

    @torch.no_grad()
    def __call__(self, images):
        """Return one aesthetic score per image as a 1-D tensor."""
        # Run on whichever device the scorer's parameters were moved to.
        device = next(self.parameters()).device
        inputs = self.processor(images=images, return_tensors="pt")
        inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
        embed = self.clip.get_image_features(**inputs)
        # normalize embedding (unit L2 norm, as the MLP head expects)
        embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
        # MLP yields (batch, 1); drop the trailing axis.
        return self.mlp(embed).squeeze(1)


def aesthetic_scorer(hub_model_id, model_filename):
    """
    Build a DDPO reward function backed by `AestheticScorer`.

    The scorer is placed on an NPU, XPU or CUDA device when available (checked in that order)
    and otherwise stays on the CPU.

    Returns:
        A callable `(images, prompts, metadata) -> (scores, {})`.
    """
    scorer = AestheticScorer(
        model_id=hub_model_id,
        model_filename=model_filename,
        dtype=torch.float32,
    )
    if is_torch_npu_available():
        scorer = scorer.npu()
    elif is_torch_xpu_available():
        scorer = scorer.xpu()
    elif torch.cuda.is_available():
        scorer = scorer.cuda()
    # else: keep the scorer on CPU; calling `.cuda()` without a CUDA device would raise.

    def _fn(images, prompts, metadata):
        # assumes `images` are float tensors in [0, 1] — TODO confirm against the pipeline output
        images = (images * 255).round().clamp(0, 255).to(torch.uint8)
        scores = scorer(images)
        return scores, {}

    return _fn


# list of example prompts to feed stable diffusion
animals = [
    "cat",
    "dog",
    "horse",
    "monkey",
    "rabbit",
    "zebra",
    "spider",
    "bird",
    "sheep",
    "deer",
    "cow",
    "goat",
    "lion",
    "frog",
    "chicken",
    "duck",
    "goose",
    "bee",
    "pig",
    "turkey",
    "fly",
    "llama",
    "camel",
    "bat",
    "gorilla",
    "hedgehog",
    "kangaroo",
]


def prompt_fn():
    """Return a random animal prompt and an empty metadata dict."""
    return np.random.choice(animals), {}


def image_outputs_logger(image_data, global_step, accelerate_logger):
    """Log the images of the last sampled batch, keyed by truncated prompt and reward."""
    # For the sake of this example, we will only log the last batch of images
    # and associated data
    result = {}
    images, prompts, _, rewards, _ = image_data[-1]

    for i, image in enumerate(images):
        prompt = prompts[i]
        reward = rewards[i].item()
        result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float()

    accelerate_logger.log_images(
        result,
        step=global_step,
    )


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, DDPOConfig))
    script_args, training_args = parser.parse_args_into_dataclasses()
    training_args.project_kwargs = {
        "logging_dir": "./logs",
        "automatic_checkpoint_naming": True,
        "total_limit": 5,
        "project_dir": "./save",
    }

    pipeline = DefaultDDPOStableDiffusionPipeline(
        script_args.pretrained_model,
        pretrained_model_revision=script_args.pretrained_revision,
        use_lora=script_args.use_lora,
    )

    trainer = DDPOTrainer(
        training_args,
        aesthetic_scorer(script_args.hf_hub_aesthetic_model_id, script_args.hf_hub_aesthetic_model_filename),
        prompt_fn,
        pipeline,
        image_samples_hook=image_outputs_logger,
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        # `ScriptArguments` defines no `dataset_name`; the destination repo is `--hf_hub_model_id`.
        trainer.push_to_hub(script_args.hf_hub_model_id)