|
''' |
|
@File : ImageReward.py |
|
@Time : 2023/01/28 19:53:00 |
|
@Author : Jiazheng Xu |
|
@Contact : xjz22@mails.tsinghua.edu.cn |
|
@Description: ImageReward Reward model. |
|
* Based on CLIP code base and improved-aesthetic-predictor code base |
|
* https://github.com/openai/CLIP |
|
* https://github.com/christophschuhmann/improved-aesthetic-predictor |
|
''' |
|
|
|
import os |
|
import torch |
|
import torch.nn as nn |
|
from PIL import Image |
|
from .models.BLIP.blip_pretrain import BLIP_Pretrain |
|
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize |
|
|
|
# Prefer torchvision's InterpolationMode enum for the bicubic constant; fall
# back to the PIL constant when that enum cannot be imported.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC
|
|
|
|
|
def _convert_image_to_rgb(image): |
|
return image.convert("RGB") |
|
|
|
|
|
def _transform(n_px):
    """Build the image preprocessing pipeline for a target size of ``n_px``.

    The composed steps run in this order: resize with bicubic interpolation,
    centre crop to ``n_px``, conversion to RGB, conversion to a tensor, and
    per-channel normalisation with the fixed mean / std constants below.

    Args:
        n_px: Size passed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``Compose`` object applying the steps above.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(channel_mean, channel_std),
    ]
    return Compose(steps)
|
|
|
|
|
class MLP(nn.Module):
    """Stack of linear and dropout layers mapping a feature vector to one value."""

    def __init__(self, input_size):
        """Create the layer stack and initialise its parameters.

        Args:
            input_size: Size of the last dimension of the input features.
        """
        super().__init__()
        self.input_size = input_size

        # The position of each module inside the Sequential determines the
        # parameter names (layers.0, layers.2, ...), so the order is fixed.
        stack = [
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

        # Every weight is drawn from N(0, 1 / (input_size + 1)) and every bias
        # is set to zero; only the Linear modules own parameters.
        weight_std = 1.0 / (self.input_size + 1)
        for layer in self.layers:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.normal_(layer.weight, mean=0.0, std=weight_std)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, input):
        """Apply the layer stack to ``input`` and return the result."""
        return self.layers(input)
|
|
|
|
|
class ImageReward(nn.Module):
    """Reward model that scores an image against a text prompt.

    The image is encoded by the BLIP visual encoder and fused with the
    tokenized prompt by the BLIP text encoder. The first-token feature of the
    fused output is mapped to a single value by an MLP head, and that value is
    standardised with the fixed ``mean`` / ``std`` stored on the instance.
    """

    def __init__(self, med_config, device='cpu'):
        """Build the BLIP backbone, the preprocessing pipeline and the MLP head.

        Args:
            med_config: Forwarded to ``BLIP_Pretrain`` as ``med_config``.
            device: Device that tokenized prompts and preprocessed images are
                moved to before being fed to the model.
        """
        super().__init__()
        self.device = device

        self.blip = BLIP_Pretrain(image_size=224, vit='large', med_config=med_config)
        self.preprocess = _transform(224)
        self.mlp = MLP(768)

        # Constants used to standardise the raw MLP output.
        self.mean = 0.16717362830052426
        self.std = 1.0333394966054072

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` (padded / truncated to 35 tokens) onto ``self.device``."""
        return self.blip.tokenizer(
            prompt,
            padding='max_length',
            truncation=True,
            max_length=35,
            return_tensors="pt",
        ).to(self.device)

    def _load_image(self, image):
        """Return a PIL image for ``image``, which is a PIL image or a file path.

        Raises:
            FileNotFoundError: ``image`` is a string that is not an existing file.
            TypeError: ``image`` is neither a string nor a PIL image.
        """
        if isinstance(image, str):
            # Previously a non-existent path could fall through without
            # binding the image and fail later with an UnboundLocalError.
            if not os.path.isfile(image):
                raise FileNotFoundError(f'Image file not found: {image!r}')
            return Image.open(image)
        if isinstance(image, Image.Image):
            return image
        raise TypeError(r'This image parameter type has not been supportted yet. Please pass PIL.Image or file path str.')

    def _fused_features(self, prompt_ids, prompt_attention_mask, image):
        """Return the first-token feature of the image-conditioned text encoding.

        ``image`` is passed straight to the visual encoder, so it is presumably
        an already preprocessed, batched tensor -- confirm against callers.
        """
        image_embeds = self.blip.visual_encoder(image)

        # Attend to every image position: an all-ones mask covering all but the
        # last dimension of the embeddings, created on the embeddings' device.
        image_atts = torch.ones(
            image_embeds.size()[:-1],
            dtype=torch.long,
            device=image_embeds.device,
        )
        text_output = self.blip.text_encoder(
            prompt_ids,
            attention_mask=prompt_attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        return text_output.last_hidden_state[:, 0, :]

    def _standardize(self, rewards):
        """Shift and scale raw MLP outputs with the stored ``mean`` and ``std``."""
        return (rewards - self.mean) / self.std

    def score_gard(self, prompt_ids, prompt_attention_mask, image):
        """Compute standardised rewards as a tensor, keeping gradient tracking.

        Unlike ``score`` this takes already tokenized prompts and an image
        tensor, and it does not disable autograd or detach the result.

        Args:
            prompt_ids: Token ids passed to the text encoder.
            prompt_attention_mask: Attention mask matching ``prompt_ids``.
            image: Input passed to the visual encoder.

        Returns:
            The standardised MLP output tensor.
        """
        txt_features = self._fused_features(prompt_ids, prompt_attention_mask, image)
        return self._standardize(self.mlp(txt_features))

    def score(self, prompt, image):
        """Score one image, or a list of images, against ``prompt``.

        Args:
            prompt: Text prompt handed to the tokenizer.
            image: A PIL image, a path to an image file, or a list / tuple of
                those. A list / tuple is delegated to ``inference_rank``.

        Returns:
            A Python float for a single image, or the rewards returned by
            ``inference_rank`` for a list / tuple.

        Raises:
            FileNotFoundError: A path string does not point to an existing file.
            TypeError: The image is neither a path string nor a PIL image.
        """
        if isinstance(image, (list, tuple)):
            _, rewards = self.inference_rank(prompt, image)
            return rewards

        # Resolve the image first so bad input fails before any model work.
        pil_image = self._load_image(image)
        text_input = self._tokenize(prompt)

        # The result is returned as a plain float, so no autograd graph is needed.
        with torch.no_grad():
            pixels = self.preprocess(pil_image).unsqueeze(0).to(self.device)
            txt_features = self._fused_features(
                text_input.input_ids, text_input.attention_mask, pixels
            ).float()
            rewards = self._standardize(self.mlp(txt_features))

        return rewards.detach().cpu().numpy().item()

    def inference_rank(self, prompt, generations_list):
        """Rank several images for one prompt by their reward.

        Args:
            prompt: Text prompt handed to the tokenizer.
            generations_list: Iterable of PIL images and / or image file paths.

        Returns:
            A ``(ranking, rewards)`` pair. ``ranking[i]`` is the 1-based rank of
            image ``i`` (1 = highest reward) and ``rewards[i]`` is its
            standardised reward.
            NOTE(review): ``torch.squeeze`` removes every size-1 dimension, so
            a single-element input appears to yield a bare int / float instead
            of one-element lists; kept as is for backward compatibility.

        Raises:
            ValueError: ``generations_list`` is empty.
            FileNotFoundError: A path string does not point to an existing file.
            TypeError: An entry is neither a path string nor a PIL image.
        """
        generations = list(generations_list)
        if not generations:
            raise ValueError('generations_list must contain at least one image.')

        text_input = self._tokenize(prompt)

        # Only Python lists are returned, so autograd graphs are not needed
        # and would otherwise be held for every image until concatenation.
        with torch.no_grad():
            txt_set = []
            for generation in generations:
                pil_image = self._load_image(generation)
                image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
                txt_set.append(
                    self._fused_features(
                        text_input.input_ids, text_input.attention_mask, image
                    )
                )

            txt_features = torch.cat(txt_set, 0).float()
            rewards = self._standardize(self.mlp(txt_features))
            rewards = torch.squeeze(rewards)

            # Sorting the descending order once more turns "image index at
            # each rank" into "rank of each image"; +1 makes ranks 1-based.
            _, rank = torch.sort(rewards, dim=0, descending=True)
            _, indices = torch.sort(rank, dim=0)
            indices = indices + 1

        return indices.detach().cpu().numpy().tolist(), rewards.detach().cpu().numpy().tolist()