# jina-embeddings-v4 / custom_lora_module.py
from __future__ import annotations

import math
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from peft.tuners.lora import LoraLayer


class MultiAdapterLinear(nn.Module, LoraLayer):
    """
    Custom LoRA module supporting multiple adapters for a linear layer.

    This module extends the standard LoRA implementation to support multiple
    task-specific adapters that can be selected dynamically during the forward
    pass. The task_label argument to forward determines which LoRA adapter(s)
    to use:

    - If task_label is a string, all examples in the batch use the same adapter.
    - If task_label is a list of strings, each example can use a different adapter.

    This enables efficient multi-task inference: all task-specific LoRA adapters
    are kept in memory simultaneously and selected per example, which removes
    the need to switch adapter state between tasks and keeps throughput high for
    mixed-task batches. A usage sketch follows the class at the bottom of this file.

    Derived from peft.tuners.lora.Linear.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        task_names: List[str],
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set to True if the layer to replace stores weight like (fan_in, fan_out)
        is_target_conv_1d_layer: bool = False,
        init_lora_weights: Union[bool, str] = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        lora_bias: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        LoraLayer.__init__(self, base_layer, **kwargs)
        self.fan_in_fan_out = fan_in_fan_out
        self.task_names = task_names
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
            lora_bias=lora_bias,
        )
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

    def forward(self, x: torch.Tensor, task_label: Union[str, List[str]], *args: Any, **kwargs: Any) -> torch.Tensor:
        self._check_forward_args(x, *args, **kwargs)

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)
            torch_result_dtype = result.dtype
            lora_A_keys = self.lora_A.keys()
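            # Add each active adapter's task-specific low-rank update on top of
            # the frozen base output. A string task_label routes the whole batch
            # through one adapter; a list routes each example individually.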
            for active_adapter in self.active_adapters:
                if active_adapter not in lora_A_keys:
                    continue
                if isinstance(task_label, str):
                    lora_A = self.lora_A[active_adapter][task_label]
                    lora_B = self.lora_B[active_adapter][task_label]
                    dropout = self.lora_dropout[active_adapter]
                    scaling = self.scaling[active_adapter]
                    x = self._cast_input_dtype(x, lora_A.weight.dtype)
                    result = result + lora_B(lora_A(dropout(x))) * scaling
                else:
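                    # Mixed batch: group examples by task so each group goes
                    # through its adapter in a single matmul, then scatter the
                    # LoRA deltas back to their original row positions.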
                    unique_tasks = list(set(task_label))
                    lora_output = torch.zeros_like(result)
                    for task in unique_tasks:
                        task_indices = [i for i, t in enumerate(task_label) if t == task]
                        task_x = x[task_indices]
                        lora_A = self.lora_A[active_adapter][task]
                        lora_B = self.lora_B[active_adapter][task]
                        dropout = self.lora_dropout[active_adapter]
                        scaling = self.scaling[active_adapter]
                        task_x = self._cast_input_dtype(task_x, lora_A.weight.dtype)
                        task_lora_value = lora_B(lora_A(dropout(task_x))) * scaling
                        lora_output[task_indices] = task_lora_value.to(lora_output.dtype)
                    result = result + lora_output

            result = result.to(torch_result_dtype)
        return result

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lora." + rep

    def update_layer(
        self,
        adapter_name,
        r,
        lora_alpha,
        lora_dropout,
        init_lora_weights,
        use_rslora,
        use_dora: bool = False,
        lora_bias: bool = False,
    ):
        # This code works for linear layers, override for other layer types
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()
        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))

        # Actual trainable parameters: one (A, B) pair per task, nested under the adapter name
        self.lora_A[adapter_name] = nn.ModuleDict({
            task_name: nn.Linear(self.in_features, r, bias=False)
            for task_name in self.task_names
        })
        self.lora_B[adapter_name] = nn.ModuleDict({
            task_name: nn.Linear(r, self.out_features, bias=lora_bias)
            for task_name in self.task_names
        })
        self.lora_bias[adapter_name] = lora_bias
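
        # rsLoRA (Kalajdzievski, 2023) scales by alpha/sqrt(r) rather than alpha/r,
        # which keeps the update magnitude stable as the rank grows.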
        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        self.reset_lora_parameters(adapter_name, init_lora_weights)
        self._move_adapter_to_device_of_base_layer(adapter_name)

        self.use_dora[adapter_name] = False
        self.set_adapter(self.active_adapters)

    def reset_lora_parameters(self, adapter_name, init_lora_weights):
        if init_lora_weights is False:
            return

        if init_lora_weights is True:
            # Initialize A the same way as the default for nn.Linear and B to zero
            # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
            for task_name in self.task_names:
                nn.init.kaiming_uniform_(self.lora_A[adapter_name][task_name].weight, a=math.sqrt(5))
        elif init_lora_weights.lower() == "gaussian":
            for task_name in self.task_names:
                nn.init.normal_(self.lora_A[adapter_name][task_name].weight, std=1 / self.r[adapter_name])
        else:
            raise ValueError(f"Unknown initialization {init_lora_weights=}")

        for task_name in self.task_names:
            nn.init.zeros_(self.lora_B[adapter_name][task_name].weight)
        if self.lora_bias[adapter_name]:
            for task_name in self.task_names:
                nn.init.zeros_(self.lora_B[adapter_name][task_name].bias)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights.

        Not supported here: adapters are selected per example at run time, so
        there is no single set of weights that could be folded into the base layer.
        """
        raise NotImplementedError("Merge operation is not supported")

    def unmerge(self) -> None:
        """
        Unmerge all merged adapter layers from the base weights.

        Not supported, since merge() is never performed for this module.
        """
        raise NotImplementedError("Unmerge operation is not supported")
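

if __name__ == "__main__":
    # Minimal usage sketch (assumes a recent peft version; the adapter and task
    # names below are illustrative, not the ones jina-embeddings-v4 ships with).
    # It wraps a plain nn.Linear and runs one homogeneous and one mixed-task batch.
    base = nn.Linear(16, 32)
    layer = MultiAdapterLinear(
        base_layer=base,
        adapter_name="default",
        task_names=["retrieval", "text-matching"],
        r=8,
        lora_alpha=16,
    )
    x = torch.randn(4, 16)

    # Every example in the batch routed through the same adapter:
    out_single = layer(x, task_label="retrieval")

    # Each example routed through its own adapter, with no adapter switching:
    out_mixed = layer(x, task_label=["retrieval", "text-matching", "retrieval", "retrieval"])
    print(out_single.shape, out_mixed.shape)  # torch.Size([4, 32]) for both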