import argparse
import logging
import os
from collections.abc import Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import Optional

import torch

from trl import TrlParser
from trl.import_utils import (
    is_fastapi_available,
    is_pydantic_available,
    is_uvicorn_available,
    is_vllm_ascend_available,
    is_vllm_available,
)


if is_fastapi_available():
    from fastapi import FastAPI


if is_pydantic_available():
    from pydantic import BaseModel


if is_uvicorn_available():
    import uvicorn


if is_vllm_available():
    from vllm import LLM, SamplingParams
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.parallel_state import get_world_group
    from vllm.distributed.utils import StatelessProcessGroup
    from vllm.sampling_params import GuidedDecodingParams
    from vllm.utils import get_open_port

if is_vllm_ascend_available():
    from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator


logger = logging.getLogger(__name__)

# CUDA cannot be re-initialized in a forked subprocess, so vLLM worker processes must be started with the "spawn"
# multiprocessing start method.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


class WeightSyncWorkerExtension:
    """
    A vLLM worker extension that enables weight synchronization between a client and multiple server workers.

    This worker uses a `StatelessProcessGroup` to establish communication and a `PyNcclCommunicator` to handle
    efficient GPU-based communication using NCCL. The primary purpose of this class is to receive updated model
    weights from a client process and distribute them to all worker processes participating in model inference.
    """

    # Both attributes are set by `init_communicator` and reset by `close_communicator`.
    pynccl_comm = None  # Communicator used to receive updated weights from the client
    client_rank = None  # Rank of the client process within the weight update group

    def init_communicator(self, host: str, port: int, world_size: int) -> None:
        """
        Initializes the weight update communicator using a stateless process group.

        This method creates a `StatelessProcessGroup` that allows external training processes to communicate with
        vLLM workers without interfering with the global torch distributed group.

        Args:
            host (`str`):
                Hostname or IP address of the master node.
            port (`int`):
                Port number to be used for communication.
            world_size (`int`):
                Total number of participating processes in the update group.
        """
        if self.pynccl_comm is not None:
            raise RuntimeError("Weight update group already initialized. Call close_communicator first.")

        # Rank of this worker within vLLM's global world group.
        rank = get_world_group().rank

        # Create a stateless process group that is independent of the default torch distributed group, and build a
        # NCCL communicator on top of it for GPU-to-GPU weight transfers.
        pg = StatelessProcessGroup.create(host=host, port=port, rank=rank, world_size=world_size)
        self.pynccl_comm = PyNcclCommunicator(pg, device=self.device)

        # The client (training) process always occupies the last rank in the update group.
        self.client_rank = world_size - 1

    def update_named_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]) -> None:
        """
        Receives updated weights from the client process and updates the named parameter in the model.

        Args:
            name (`str`):
                Name of the weight tensor being updated.
            dtype (`torch.dtype`):
                Data type of the weight tensor (e.g., `torch.float32`).
            shape (`Sequence[int]`):
                Shape of the weight tensor.
        """
        if self.pynccl_comm is None:
            raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")

        # Allocate a buffer on this worker's device to receive the incoming weight.
        weight = torch.empty(shape, dtype=dtype, device=self.device)

        # Receive the weight broadcast by the client process and synchronize all participants.
        self.pynccl_comm.broadcast(weight, src=self.client_rank)
        self.pynccl_comm.group.barrier()

        # Load the received weight into the running model.
        self.model_runner.model.load_weights(weights=[(name, weight)])

    def close_communicator(self) -> None:
        """
        Closes the communicator when weight synchronization is no longer needed.

        This method deletes the NCCL communicator to release associated resources.
        """
        if self.pynccl_comm is not None:
            del self.pynccl_comm

            # Reset the attributes so a new communicator can be initialized later.
            self.pynccl_comm = None
            self.client_rank = None


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        model (`str`):
            Model name or path to load the model from.
        revision (`str` or `None`, *optional*, defaults to `None`):
            Revision to use for the model. If not specified, the default branch will be used.
        tensor_parallel_size (`int`, *optional*, defaults to `1`):
            Number of tensor parallel workers to use.
        data_parallel_size (`int`, *optional*, defaults to `1`):
            Number of data parallel workers to use.
        host (`str`, *optional*, defaults to `"0.0.0.0"`):
            Host address to run the server on.
        port (`int`, *optional*, defaults to `8000`):
            Port to run the server on.
        gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
            Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
            device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
            improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
            during initialization.
        dtype (`str`, *optional*, defaults to `"auto"`):
            Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
            based on the model configuration. Find the supported values in the vLLM documentation.
        max_model_len (`int` or `None`, *optional*, defaults to `None`):
            If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
            `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
            context size, which might be much larger than the KV cache, leading to inefficiencies.
        enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`):
            Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support
            this feature.
        enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
            Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
            model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
        log_level (`str`, *optional*, defaults to `"info"`):
            Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`,
            `"trace"`.
    """

    model: str = field(
        metadata={"help": "Model name or path to load the model from."},
    )
    revision: Optional[str] = field(
        default=None,
        metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."},
    )
    tensor_parallel_size: int = field(
        default=1,
        metadata={"help": "Number of tensor parallel workers to use."},
    )
    data_parallel_size: int = field(
        default=1,
        metadata={"help": "Number of data parallel workers to use."},
    )
    host: str = field(
        default="0.0.0.0",
        metadata={"help": "Host address to run the server on."},
    )
    port: int = field(
        default=8000,
        metadata={"help": "Port to run the server on."},
    )
    gpu_memory_utilization: float = field(
        default=0.9,
        metadata={
            "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
            "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
            "size and thus improve the model's throughput. However, if the value is too high, it may cause "
            "out-of-memory (OOM) errors during initialization."
        },
    )
    dtype: str = field(
        default="auto",
        metadata={
            "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
            "determined based on the model configuration. Find the supported values in the vLLM documentation."
        },
    )
    max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
            "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
            "context size, which might be much larger than the KV cache, leading to inefficiencies."
        },
    )
    enable_prefix_caching: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
            "hardware support this feature."
        },
    )
    enforce_eager: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
            "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager "
            "execution in hybrid."
        },
    )
    log_level: str = field(
        default="info",
        metadata={
            "help": "Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', "
            "'trace'."
        },
    )
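

# For illustration only: with TRL's CLI, which registers `make_parser` below under the `vllm-serve` subcommand, a
# server for a hypothetical model could be launched with something like:
#   trl vllm-serve --model Qwen/Qwen2.5-7B-Instruct --tensor_parallel_size 2 --data_parallel_size 2
# The model name and parallelism values here are placeholders, not recommendations.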


def llm_worker(
    script_args: ScriptArguments, data_parallel_rank: int, master_port: int, connection: Connection
) -> None:
    # Set the environment variables required for vLLM data parallelism.
    os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
    os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
    os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
    os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)

    llm = LLM(
        model=script_args.model,
        revision=script_args.revision,
        tensor_parallel_size=script_args.tensor_parallel_size,
        gpu_memory_utilization=script_args.gpu_memory_utilization,
        enforce_eager=script_args.enforce_eager,
        dtype=script_args.dtype,
        # Prefix caching reuses the KV cache of previous queries that share the same prefix, which is useful when
        # generating several completions from the same prompts.
        enable_prefix_caching=script_args.enable_prefix_caching,
        max_model_len=script_args.max_model_len,
        worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
    )

    # Signal to the parent process that the worker is ready.
    connection.send({"status": "ready"})

    while True:
        # Wait for the next command from the parent process.
        try:
            command = connection.recv()
        except KeyboardInterrupt:
            llm.collective_rpc(method="close_communicator")
            break

        if command["type"] in ["call", "fire_and_forget"]:
            method_name = command["method"]
            args, kwargs = command.get("args", ()), command.get("kwargs", {})
            method = getattr(llm, method_name)
            result = method(*args, **kwargs)
            # Only "call" commands expect a response; "fire_and_forget" commands do not.
            if command["type"] == "call":
                connection.send(result)
        elif command["type"] == "shutdown":
            break
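

# Messages exchanged with `llm_worker` over the pipe are plain dicts. For illustration, the parent process sends:
#   {"type": "call", "method": "generate", "kwargs": {...}}         -> the worker replies with the return value
#   {"type": "fire_and_forget", "method": "collective_rpc", ...}    -> no reply is sent back
#   {"type": "shutdown"}                                            -> the worker exits its loop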


def chunk_list(lst: list, n: int) -> list[list]:
    """
    Split list `lst` into `n` evenly distributed sublists.

    Example:
        >>> chunk_list([1, 2, 3, 4, 5, 6], 2)
        [[1, 2, 3], [4, 5, 6]]
        >>> chunk_list([1, 2, 3, 4, 5, 6], 4)
        [[1, 2], [3, 4], [5], [6]]
        >>> chunk_list([1, 2, 3, 4, 5, 6], 8)
        [[1], [2], [3], [4], [5], [6], [], []]
    """
    k, r = divmod(len(lst), n)
    return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)]


def main(script_args: ScriptArguments):
    if not is_fastapi_available():
        raise ImportError(
            "FastAPI is required to run the vLLM serve script. Please install it using `pip install fastapi`."
        )

    if not is_pydantic_available():
        raise ImportError(
            "Pydantic is required to run the vLLM serve script. Please install it using `pip install pydantic`."
        )

    if not is_uvicorn_available():
        raise ImportError(
            "Uvicorn is required to run the vLLM serve script. Please install it using `pip install uvicorn`."
        )

    if not is_vllm_available():
        raise ImportError("vLLM is required to run the vLLM serve script. Please install it using `pip install vllm`.")

    # Spawn one LLM worker process per data-parallel rank and set up a pipe to communicate with each of them.
    master_port = get_open_port()
    connections = []
    processes = []
    for data_parallel_rank in range(script_args.data_parallel_size):
        parent_connection, child_connection = Pipe()
        process = Process(target=llm_worker, args=(script_args, data_parallel_rank, master_port, child_connection))
        process.start()
        connections.append(parent_connection)
        processes.append(process)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Wait until every worker has signaled that it is ready before serving requests.
        ready_connections = set()
        while len(ready_connections) < script_args.data_parallel_size:
            for connection in connections:
                msg = connection.recv()
                if isinstance(msg, dict) and msg.get("status") == "ready":
                    ready_connections.add(connection)

        yield

        # On shutdown, wait for the worker processes to terminate, forcefully if necessary.
        for process in processes:
            process.join(timeout=10)
            if process.is_alive():
                logger.warning(f"Process {process} is still alive after 10 seconds, attempting to terminate...")
                process.terminate()
                process.join()

    app = FastAPI(lifespan=lifespan)

    @app.get("/health/")
    async def health():
        """
        Health check endpoint to verify that the server is running.
        """
        return {"status": "ok"}
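
    # For example, with the default host/port (placeholders, adjust to your deployment):
    #   curl http://localhost:8000/health/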

    @app.get("/get_world_size/")
    async def get_world_size():
        """
        Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`.

        Returns:
            `dict`:
                A dictionary containing the world size.

        Example response:
        ```json
        {"world_size": 8}
        ```
        """
        return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}

    class GenerateRequest(BaseModel):
        prompts: list[str]
        n: int = 1
        repetition_penalty: float = 1.0
        temperature: float = 1.0
        top_p: float = 1.0
        top_k: int = -1
        min_p: float = 0.0
        max_tokens: int = 16
        guided_decoding_regex: Optional[str] = None

    class GenerateResponse(BaseModel):
        completion_ids: list[list[int]]

    @app.post("/generate/", response_model=GenerateResponse)
    async def generate(request: GenerateRequest):
        """
        Generates completions for the provided prompts.

        Args:
            request (`GenerateRequest`):
                - `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions.
                - Optional sampling parameters mirroring `SamplingParams`: `n`, `repetition_penalty`, `temperature`,
                  `top_p`, `top_k`, `min_p`, `max_tokens`, and `guided_decoding_regex`.

        Returns:
            `GenerateResponse`:
                - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.

        Example request:
        ```json
        {"prompts": ["Hello world", "What is AI?"]}
        ```

        Example response:
        ```json
        {"completion_ids": [[101, 102, 103], [201, 202, 203]]}
        ```
        """
        # Set up guided decoding, if requested.
        if request.guided_decoding_regex is not None:
            guided_decoding = GuidedDecodingParams(backend="outlines", regex=request.guided_decoding_regex)
        else:
            guided_decoding = None

        sampling_params = SamplingParams(
            n=request.n,
            repetition_penalty=request.repetition_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            min_p=request.min_p,
            max_tokens=request.max_tokens,
            guided_decoding=guided_decoding,
        )

        # Distribute the prompts evenly across the data-parallel workers.
        chunked_prompts = chunk_list(request.prompts, script_args.data_parallel_size)

        for connection, prompts in zip(connections, chunked_prompts):
            # When there are fewer prompts than workers, some chunks are empty. vLLM still requires at least one
            # prompt per call, so send a placeholder prompt to such workers and discard their output below.
            if not prompts:
                prompts = ["<placeholder>"]
            kwargs = {"prompts": prompts, "sampling_params": sampling_params}
            connection.send({"type": "call", "method": "generate", "kwargs": kwargs})

        # Collect the results from every worker.
        all_outputs = [connection.recv() for connection in connections]

        # Drop the outputs that correspond to placeholder prompts (empty chunks).
        all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts) if prompts]

        # Flatten the per-worker outputs into a single list of completions.
        all_outputs = list(chain.from_iterable(all_outputs))
        completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
        return {"completion_ids": completion_ids}
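
    # For illustration (placeholder host/port), a request to this endpoint could look like:
    #   curl -X POST http://localhost:8000/generate/ \
    #        -H "Content-Type: application/json" \
    #        -d '{"prompts": ["Hello world"], "n": 1, "max_tokens": 32}'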

    class InitCommunicatorRequest(BaseModel):
        host: str
        port: int
        world_size: int

    @app.post("/init_communicator/")
    async def init_communicator(request: InitCommunicatorRequest):
        """
        Initializes the communicator for synchronizing model weights between a client and multiple server workers.

        Args:
            request (`InitCommunicatorRequest`):
                - `host` (`str`): Hostname or IP address of the master node.
                - `port` (`int`): Port number to be used for communication.
                - `world_size` (`int`): Total number of participating processes in the group.
        """
        # The update group contains every vLLM worker plus the client process, hence the +1.
        world_size = script_args.tensor_parallel_size * script_args.data_parallel_size + 1

        # Forward the call to every worker: `collective_rpc` invokes
        # `WeightSyncWorkerExtension.init_communicator(host, port, world_size)` on each of them.
        kwargs = {"method": "init_communicator", "args": (request.host, request.port, world_size)}
        for connection in connections:
            connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})

        return {"message": "Request received, initializing communicator"}
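
    # After this endpoint returns, the client is expected to join the group at rank `world_size - 1` (see
    # `WeightSyncWorkerExtension.init_communicator`) before broadcasting any weights.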

    class UpdateWeightsRequest(BaseModel):
        name: str
        dtype: str
        shape: list[int]

    @app.post("/update_named_param/")
    async def update_named_param(request: UpdateWeightsRequest):
        """
        Updates the model weights with the provided tensor.

        Once this endpoint is called, the client process should broadcast the updated weights to all server workers.

        Args:
            request (`UpdateWeightsRequest`):
                - `name` (`str`): Name of the weight tensor being updated.
                - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
                - `shape` (list of `int`): Shape of the weight tensor.
        """
        # The request carries the dtype as a string (e.g. "torch.float32"); map it back to the torch dtype object
        # before forwarding the call to the workers.
        dtype = getattr(torch, request.dtype.split(".")[-1])
        kwargs = {"method": "update_named_param", "args": (request.name, dtype, tuple(request.shape))}
        for connection in connections:
            connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})

        return {"message": "Request received, updating named parameter"}

    @app.post("/reset_prefix_cache/")
    async def reset_prefix_cache():
        """
        Resets the prefix cache for the model.
        """
        for connection in connections:
            connection.send({"type": "call", "method": "reset_prefix_cache"})

        # Wait for the result from every worker and report whether all of them succeeded.
        all_outputs = [connection.recv() for connection in connections]
        success = all(output for output in all_outputs)
        return {"message": "Request received, resetting prefix cache status: " + str(success)}

    @app.post("/close_communicator/")
    async def close_communicator():
        """
        Closes the weight update group and cleans up associated resources.
        """
        kwargs = {"method": "close_communicator"}
        for connection in connections:
            connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
        return {"message": "Request received, closing communicator"}

    # Start the FastAPI app (blocks until the server is shut down).
    uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level)


def make_parser(subparsers: argparse._SubParsersAction = None):
    if subparsers is not None:
        parser = subparsers.add_parser("vllm-serve", help="Run the vLLM serve script", dataclass_types=ScriptArguments)
    else:
        parser = TrlParser(ScriptArguments)
    return parser
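

# The script can also be run standalone via the module path referenced by `worker_extension_cls` above, e.g.
# (illustrative, depends on your install):
#   python -m trl.scripts.vllm_serve --model <model_name_or_path>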


if __name__ == "__main__":
    parser = make_parser()
    (script_args,) = parser.parse_args_and_config()
    main(script_args)