# Copyright 2024 Bytedance Ltd. and/or its affiliates # # Permission is hereby granted, free of charge, to any person obtaining a copy of this software # and associated documentation files (the "Software"), to deal in the Software without # restriction, including without limitation the rights to use, copy, modify, merge, publish, # distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the # Software is furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all copies or # substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR # OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR # OTHER DEALINGS IN THE SOFTWARE. from __future__ import annotations import torch from torch import nn from torch.nn import functional as F from transformers import ( PreTrainedModel, BertModel, AutoTokenizer, ) import os from transformers.modeling_outputs import SequenceClassifierOutput from typing import Union, List, Optional from collections import defaultdict import numpy as np import math from huggingface_hub import hf_hub_download from .configuration_listconranker import ListConRankerConfig class QueryEmbedding(nn.Module): def __init__(self, config) -> None: super().__init__() self.query_embedding = nn.Embedding(2, config.list_con_hidden_size) self.layerNorm = nn.LayerNorm(config.list_con_hidden_size) def forward(self, x, tags): query_embeddings = self.query_embedding(tags) x += query_embeddings x = self.layerNorm(x) return x class ListTransformer(nn.Module): def __init__(self, num_layer, config) -> None: super().__init__() self.config = config self.list_transformer_layer = nn.TransformerEncoderLayer( config.list_con_hidden_size, self.config.num_attention_heads, batch_first=True, activation=F.gelu, norm_first=False, ) self.list_transformer = nn.TransformerEncoder( self.list_transformer_layer, num_layer ) self.relu = nn.ReLU() self.query_embedding = QueryEmbedding(config) self.linear_score3 = nn.Linear( config.list_con_hidden_size * 2, config.list_con_hidden_size ) self.linear_score2 = nn.Linear( config.list_con_hidden_size * 2, config.list_con_hidden_size ) self.linear_score1 = nn.Linear(config.list_con_hidden_size * 2, 1) def forward( self, pair_features: torch.Tensor, pair_nums: List[int] ) -> torch.Tensor: batch_pair_features = pair_features.split(pair_nums) pair_feature_query_passage_concat_list = [] for i in range(len(batch_pair_features)): pair_feature_query = ( batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1) ) pair_feature_passage = batch_pair_features[i][1:] pair_feature_query_passage_concat_list.append( torch.cat([pair_feature_query, pair_feature_passage], dim=1) ) pair_feature_query_passage_concat = torch.cat( pair_feature_query_passage_concat_list, dim=0 ) batch_pair_features = nn.utils.rnn.pad_sequence( batch_pair_features, batch_first=True ) query_embedding_tags = torch.zeros( batch_pair_features.size(0), batch_pair_features.size(1), dtype=torch.long, device=self.device, ) query_embedding_tags[:, 0] = 1 batch_pair_features = self.query_embedding( batch_pair_features, query_embedding_tags ) mask = self.generate_attention_mask(pair_nums) query_mask = self.generate_attention_mask_custom(pair_nums) pair_list_features = self.list_transformer( batch_pair_features, src_key_padding_mask=mask, mask=query_mask ) output_pair_list_features = [] output_query_list_features = [] pair_features_after_transformer_list = [] for idx, pair_num in enumerate(pair_nums): output_pair_list_features.append(pair_list_features[idx, 1:pair_num, :]) output_query_list_features.append(pair_list_features[idx, 0, :]) pair_features_after_transformer_list.append( pair_list_features[idx, :pair_num, :] ) pair_features_after_transformer_cat_query_list = [] for idx, pair_num in enumerate(pair_nums): query_ft = ( output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1) ) pair_features_after_transformer_cat_query = torch.cat( [query_ft, output_pair_list_features[idx]], dim=1 ) pair_features_after_transformer_cat_query_list.append( pair_features_after_transformer_cat_query ) pair_features_after_transformer_cat_query = torch.cat( pair_features_after_transformer_cat_query_list, dim=0 ) pair_feature_query_passage_concat = self.relu( self.linear_score2(pair_feature_query_passage_concat) ) pair_features_after_transformer_cat_query = self.relu( self.linear_score3(pair_features_after_transformer_cat_query) ) final_ft = torch.cat( [ pair_feature_query_passage_concat, pair_features_after_transformer_cat_query, ], dim=1, ) logits = self.linear_score1(final_ft).squeeze() return logits, torch.cat(pair_features_after_transformer_list, dim=0) def generate_attention_mask(self, pair_num): max_len = max(pair_num) batch_size = len(pair_num) mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device=self.device) for i, length in enumerate(pair_num): mask[i, length:] = True return mask def generate_attention_mask_custom(self, pair_num): max_len = max(pair_num) mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=self.device) mask[0, 1:] = True return mask class ListConRankerModel(PreTrainedModel): """ ListConRanker model for sequence classification that's compatible with AutoModelForSequenceClassification. """ config_class = ListConRankerConfig base_model_prefix = "listconranker" def __init__(self, config: ListConRankerConfig): super().__init__(config) self.config = config self.num_labels = config.num_labels self.hf_model = BertModel(config.bert_config) self.sigmoid = nn.Sigmoid() self.linear_in_embedding = nn.Linear( config.hidden_size, config.list_con_hidden_size ) self.list_transformer = ListTransformer( config.list_transformer_layers, config, ) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: if self.training: raise NotImplementedError("Training not supported; use eval mode.") device = input_ids.device self.list_transformer.device = device # Reorganize by unique queries and their passages ( reorganized_input_ids, reorganized_attention_mask, reorganized_token_type_ids, pair_nums, group_indices, ) = self._reorganize_inputs(input_ids, attention_mask, token_type_ids) out = self.hf_model( input_ids=reorganized_input_ids, attention_mask=reorganized_attention_mask, token_type_ids=reorganized_token_type_ids, return_dict=True, ) feats = out.last_hidden_state pooled = self.average_pooling(feats, reorganized_attention_mask) embedded = self.linear_in_embedding(pooled) logits, _ = self.list_transformer(embedded, pair_nums) probs = self.sigmoid(logits) # Restore original order sorted_probs = self._restore_original_order(probs, group_indices) sorted_logits = self._restore_original_order(logits, group_indices) if not return_dict: return (sorted_probs, sorted_logits) return SequenceClassifierOutput( loss=None, logits=sorted_logits, hidden_states=out.hidden_states, attentions=out.attentions, ) def _reorganize_inputs( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: Optional[torch.Tensor], ) -> tuple[ torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int], List[List[int]] ]: """ Group inputs by unique queries: for each query, produce [query] + its passages, then flatten, pad, and return pair sizes and original indices mapping. """ batch_size = input_ids.size(0) # Structure: query_key -> { # 'query': (seq, mask, tt), # 'passages': [(seq, mask, tt), ...], # 'indices': [original_index, ...] # } grouped = {} for idx in range(batch_size): seq = input_ids[idx] mask = attention_mask[idx] token_type_ids[idx] if token_type_ids is not None else torch.zeros_like(seq) sep_idxs = (seq == self.config.sep_token_id).nonzero(as_tuple=True)[0] if sep_idxs.numel() == 0: raise ValueError(f"No SEP in sequence {idx}") first_sep = sep_idxs[0].item() second_sep = sep_idxs[1].item() # Extract query and passage q_seq = seq[: first_sep + 1] q_mask = mask[: first_sep + 1] q_tt = torch.zeros_like(q_seq) p_seq = seq[first_sep : second_sep + 1] p_mask = mask[first_sep : second_sep + 1] p_seq = p_seq.clone() p_seq[0] = self.config.cls_token_id p_tt = torch.zeros_like(p_seq) # Build key excluding CLS/SEP key = tuple( q_seq[ (q_seq != self.config.cls_token_id) & (q_seq != self.config.sep_token_id) ].tolist() ) # truncation q_seq = q_seq[: self.config.max_position_embeddings] q_seq[-1] = self.config.sep_token_id p_seq = p_seq[: self.config.max_position_embeddings] p_seq[-1] = self.config.sep_token_id q_mask = q_mask[: self.config.max_position_embeddings] p_mask = p_mask[: self.config.max_position_embeddings] q_tt = q_tt[: self.config.max_position_embeddings] p_tt = p_tt[: self.config.max_position_embeddings] if key not in grouped: grouped[key] = { "query": (q_seq, q_mask, q_tt), "passages": [], "indices": [], } grouped[key]["passages"].append((p_seq, p_mask, p_tt)) grouped[key]["indices"].append(idx) # Flatten according to group insertion order seqs, masks, tts, pair_nums, group_indices = [], [], [], [], [] for key, data in grouped.items(): q_seq, q_mask, q_tt = data["query"] passages = data["passages"] indices = data["indices"] # record sizes and original positions pair_nums.append(len(passages) + 1) # +1 for the query group_indices.append(indices) # append query then its passages seqs.append(q_seq) masks.append(q_mask) tts.append(q_tt) for p_seq, p_mask, p_tt in passages: seqs.append(p_seq) masks.append(p_mask) tts.append(p_tt) # Pad to uniform length max_len = max(s.size(0) for s in seqs) padded_seqs, padded_masks, padded_tts = [], [], [] for s, m, t in zip(seqs, masks, tts): ps = torch.zeros(max_len, dtype=s.dtype, device=s.device) pm = torch.zeros(max_len, dtype=m.dtype, device=m.device) pt = torch.zeros(max_len, dtype=t.dtype, device=t.device) ps[: s.size(0)] = s pm[: m.size(0)] = m pt[: t.size(0)] = t padded_seqs.append(ps) padded_masks.append(pm) padded_tts.append(pt) rid = torch.stack(padded_seqs) ram = torch.stack(padded_masks) rtt = torch.stack(padded_tts) if token_type_ids is not None else None return rid, ram, rtt, pair_nums, group_indices def _restore_original_order( self, logits: torch.Tensor, group_indices: List[List[int]], ) -> torch.Tensor: """ Map flattened logits back so each original index gets its passage score. """ out = torch.zeros(logits.size(0), dtype=logits.dtype, device=logits.device) i = 0 for indices in group_indices: for idx in indices: out[idx] = logits[i] i += 1 return out.reshape(-1, 1) def average_pooling(self, hidden_state, attention_mask): extended_attention_mask = ( attention_mask.unsqueeze(-1) .expand(hidden_state.size()) .to(dtype=hidden_state.dtype) ) masked_hidden_state = hidden_state * extended_attention_mask sum_embeddings = torch.sum(masked_hidden_state, dim=1) sum_mask = extended_attention_mask.sum(dim=1) return sum_embeddings / sum_mask @classmethod def from_pretrained( cls, model_name_or_path, config: Optional[ListConRankerConfig] = None, **kwargs ): model = super().from_pretrained(model_name_or_path, config=config, **kwargs) model.hf_model = BertModel.from_pretrained( model_name_or_path, config=model.config.bert_config, **kwargs ) linear_path = hf_hub_download( repo_id = model_name_or_path, filename = "linear_in_embedding.pt", revision = "main", cache_dir = kwargs['cache_dir'] if 'cache_dir' in kwargs else None ) list_transformer_path = hf_hub_download( repo_id = "ByteDance/ListConRanker", filename = "list_transformer.pt", revision = "main", cache_dir = kwargs['cache_dir'] if 'cache_dir' in kwargs else None ) try: model.linear_in_embedding.load_state_dict(torch.load(linear_path)) model.list_transformer.load_state_dict(torch.load(list_transformer_path)) except FileNotFoundError as e: raise e return model def multi_passage( self, sentences: List[List[str]], batch_size: int = 32, tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained( "ByteDance/ListConRanker" ), ): """ Process multiple passages for each query. :param sentences: List of lists, where each inner list contains sentences for a query. :return: Tensor of logits for each passage. """ pairs = [] for batch in sentences: if len(batch) < 2: raise ValueError("Each query must have at least one passage.") query = batch[0] passages = batch[1:] for passage in passages: pairs.append((query, passage)) total_batches = (len(pairs) + batch_size - 1) // batch_size total_logits = torch.zeros(len(pairs), dtype=torch.float, device=self.device) for batch in range(total_batches): batch_pairs = pairs[batch * batch_size : (batch + 1) * batch_size] inputs = tokenizer( batch_pairs, padding=True, truncation=False, return_tensors="pt", ) for k, v in inputs.items(): inputs[k] = v.to(self.device) logits = self(**inputs)[0] total_logits[batch * batch_size : (batch + 1) * batch_size] = ( logits.squeeze(1) ) return total_logits.tolist() def multi_passage_in_iterative_inference( self, sentences: List[str], stop_num: int = 20, decrement_rate: float = 0.2, min_filter_num: int = 10, tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained( "ByteDance/ListConRanker" ), ): """ Process multiple passages for one query in iterative inference. :param sentences: List contains sentences for a query. :return: Tensor of logits for each passage. """ if stop_num < 1: raise ValueError("stop_num must be greater than 0") if decrement_rate <= 0 or decrement_rate >= 1: raise ValueError("decrement_rate must be in (0, 1)") if min_filter_num < 1: raise ValueError("min_filter_num must be greater than 0") query = sentences[0] passage = sentences[1:] filter_times = 0 passage2score = defaultdict(list) while len(passage) > stop_num: batch = [[query] + passage] pred_scores = self.multi_passage( batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer ) pred_scores_argsort = np.argsort( pred_scores ).tolist() # Sort in increasing order passage_len = len(passage) to_filter_num = math.ceil(passage_len * decrement_rate) if to_filter_num < min_filter_num: to_filter_num = min_filter_num have_filter_num = 0 while have_filter_num < to_filter_num: idx = pred_scores_argsort[have_filter_num] passage2score[passage[idx]].append(pred_scores[idx] + filter_times) have_filter_num += 1 while ( pred_scores[pred_scores_argsort[have_filter_num - 1]] == pred_scores[pred_scores_argsort[have_filter_num]] ): idx = pred_scores_argsort[have_filter_num] passage2score[passage[idx]].append(pred_scores[idx] + filter_times) have_filter_num += 1 next_passage = [] next_passage_idx = have_filter_num while next_passage_idx < len(passage): idx = pred_scores_argsort[next_passage_idx] next_passage.append(passage[idx]) next_passage_idx += 1 passage = next_passage filter_times += 1 batch = [[query] + passage] pred_scores = self.multi_passage( batch, batch_size=len(batch[0]) - 1, tokenizer=tokenizer ) cnt = 0 while cnt < len(passage): passage2score[passage[cnt]].append(pred_scores[cnt] + filter_times) cnt += 1 passage = sentences[1:] final_score = [] for i in range(len(passage)): p = passage[i] final_score.append(passage2score[p][0]) return final_score