import atexit
import logging
import time
from typing import Optional

import torch
from torch import nn

from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available


if is_requests_available():
    import requests
    from requests import ConnectionError


if is_vllm_available():
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

if is_vllm_ascend_available():
    from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator


logger = logging.getLogger(__name__)


class VLLMClient:
    """
    A client class to interact with a vLLM server.

    This class provides methods to generate completions, initialize and manage weight update groups, and update model
    weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

    Args:
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            IP address of the vLLM server.
        server_port (`int`, *optional*, defaults to `8000`):
            Port number of the vLLM server.
        group_port (`int`, *optional*, defaults to `51216`):
            Port number for the weight update group.
        connection_timeout (`float`, *optional*, defaults to `0.0`):
            Total timeout duration in seconds to wait for the server to be up. If the server is not up after the
            timeout, a `ConnectionError` is raised.

    Examples:
        Run the vLLM server with the model `Qwen/Qwen2.5-7B`:

        ```
        $ trl vllm-serve --model Qwen/Qwen2.5-7B
        ...
        INFO: Application startup complete.
        INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
        ```

        Use the client to generate completions and update model weights:

        ```python
        >>> from trl.extras.vllm_client import VLLMClient
        >>> client = VLLMClient()
        >>> client.generate(["Hello, AI!", "Tell me a joke"])
        [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
         [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]

        >>> from transformers import AutoModelForCausalLM
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
        >>> client.init_communicator()
        >>> client.update_model_params(model)
        ```
    """

    def __init__(
        self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
    ):
        if not is_requests_available():
            raise ImportError("requests is not installed. Please install it with `pip install requests`.")
        if not is_vllm_available():
            raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")

        self.session = requests.Session()
        self.host = host
        self.server_port = server_port
        self.group_port = group_port
        self.check_server(connection_timeout)

    def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):
        """
        Check server availability with retries on failure, within a total timeout duration. If the server is not up
        after the total timeout duration, raise a `ConnectionError`.

        Args:
            total_timeout (`float`, *optional*, defaults to `0.0`):
                Total timeout duration in seconds.
            retry_interval (`float`, *optional*, defaults to `2.0`):
                Interval in seconds between retries.
        """
        url = f"http://{self.host}:{self.server_port}/health/"
        start_time = time.time()

        while True:
            try:
                response = requests.get(url)
            except requests.exceptions.RequestException as exc:
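                # The server is not reachable yet; check whether the total timeout has elapsed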
                elapsed_time = time.time() - start_time
                if elapsed_time >= total_timeout:
                    raise ConnectionError(
                        f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
                        "seconds. Make sure the server is running by running `trl vllm-serve`."
                    ) from exc
            else:
                if response.status_code == 200:
                    logger.info("Server is up!")
                    return None
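
            # The request failed or the server is still starting up: wait before retrying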
            logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...")
            time.sleep(retry_interval)

    def generate(
        self,
        prompts: list[str],
        n: int = 1,
        repetition_penalty: float = 1.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        max_tokens: int = 16,
        guided_decoding_regex: Optional[str] = None,
    ) -> list[list[int]]:
        """
        Generates model completions for the provided prompts.

        Args:
            prompts (`list[str]`):
                List of text prompts for which the model will generate completions.
            n (`int`, *optional*, defaults to `1`):
                Number of completions to generate for each prompt.
            repetition_penalty (`float`, *optional*, defaults to `1.0`):
                Parameter for repetition penalty. 1.0 means no penalty.
            temperature (`float`, *optional*, defaults to `1.0`):
                Temperature parameter for sampling. Higher values increase diversity.
            top_p (`float`, *optional*, defaults to `1.0`):
                Top-p sampling parameter. `1.0` means no truncation.
            top_k (`int`, *optional*, defaults to `-1`):
                Top-k sampling parameter. `-1` means no truncation.
            min_p (`float`, *optional*, defaults to `0.0`):
                Minimum probability for sampling.
            max_tokens (`int`, *optional*, defaults to `16`):
                Maximum number of tokens to generate for each prompt.
            guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
                Regular expression to guide the decoding process.

        Returns:
            `list[list[int]]`:
                List of lists of token IDs representing the model-generated completions for each prompt.
        """
        url = f"http://{self.host}:{self.server_port}/generate/"
        response = self.session.post(
            url,
            json={
                "prompts": prompts,
                "n": n,
                "repetition_penalty": repetition_penalty,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "min_p": min_p,
                "max_tokens": max_tokens,
                "guided_decoding_regex": guided_decoding_regex,
            },
        )
        if response.status_code == 200:
            return response.json()["completion_ids"]
        else:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

    def init_communicator(self):
        """
        Initializes the weight update group in a distributed setup for model synchronization.
        """
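        # Ask the server for its world size (the number of processes serving the model)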
        url = f"http://{self.host}:{self.server_port}/get_world_size/"
        response = requests.get(url)
        if response.status_code == 200:
            vllm_world_size = response.json()["world_size"]
        else:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")
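
        # This client joins the group as one extra rank, placed after the vLLM workers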
        world_size = vllm_world_size + 1
        self.rank = vllm_world_size
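
        # Ask the server to set up its side of the weight update group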
        url = f"http://{self.host}:{self.server_port}/init_communicator/"
        response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")
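
        # Brief delay to give the server time to set up the communication group before the client joins it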
        time.sleep(0.1)
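
        # Join the stateless process group as the last rank and wrap it in a NCCL communicator
        # (or an HCCL communicator on Ascend devices) for weight broadcasting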
        pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size)
        self.pynccl_comm = PyNcclCommunicator(pg, device=0)
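
        # Close the weight update group when the client process exits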
        atexit.register(self.close_communicator)

    def update_named_param(self, name: str, weights: torch.Tensor):
        """
        Updates a specific named parameter in the model and broadcasts it to other processes.

        Args:
            name (`str`):
                Name of the layer whose weights are being updated.
            weights (`torch.Tensor`):
                Tensor containing the updated weights.
        """
        dtype, shape = str(weights.dtype), tuple(weights.shape)
        url = f"http://{self.host}:{self.server_port}/update_named_param/"
        response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")
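
        # Broadcast the weights from this client (the last rank) to the vLLM workers, then synchronize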
        self.pynccl_comm.broadcast(weights, src=self.rank)
        self.pynccl_comm.group.barrier()

    def update_model_params(self, model: nn.Module):
        """
        Updates all parameters of the given model by calling `update_named_param` for each parameter in the model.

        Args:
            model (`nn.Module`):
                Model whose parameters (weights/biases) are to be updated.
        """
        for name, param in model.named_parameters():
            self.update_named_param(name, param.data)

    def reset_prefix_cache(self):
        """
        Resets the prefix cache for the model.
        """
        url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
        response = self.session.post(url)
        if response.status_code != 200:
            raise Exception(f"Request failed: {response.status_code}, {response.text}")

    def close_communicator(self):
        """
        Closes the weight update group and cleans up the communication group.
        """
        url = f"http://{self.host}:{self.server_port}/close_communicator/"

        try:
            response = self.session.post(url)
        except ConnectionError:
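            # The server might already be down, so there is nothing to close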
            pass
        else:
            if response.status_code != 200:
                raise Exception(f"Request failed: {response.status_code}, {response.text}")


if __name__ == "__main__":
    client = VLLMClient()
    client.init_communicator()
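
    # Generate completions for a couple of prompts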
    responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32)
    print("Responses:", responses)
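
    # Load a model and push its current weights to the vLLM server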
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda")
    client.update_model_params(model)
|