[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This file defines an interface `KVPipeBase`
that provides an abstraction for sending and receiving tensors, or None, via
distributed communications.
All classes implementing this interface are assumed to be FIFO pipes.
If your distributed communication platform already supports key-value lookup,
you can bypass this interface and directly start from `kv_lookup_buffer`.
"""
from abc import ABC, abstractmethod
from typing import Optional
import torch
class KVPipeBase(ABC):
    """Abstract FIFO pipe that carries tensors (or ``None``) between
    processes over some distributed communication mechanism.

    Subclasses provide the actual transport by implementing every method
    declared here.
    """

    @abstractmethod
    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """Push one item into the pipe.

        Implementations must accept ``None`` as a payload -- this matters
        for error handling.

        TODO: add a `key` argument so that we can use traditional
        key-value database as the distributed communication mechanism behind
        the pipe.

        Args:
            tensor (Optional[torch.Tensor]): The item to send; may be None.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def recv_tensor(self) -> Optional[torch.Tensor]:
        """Pull the next item out of the pipe.

        Returns:
            Optional[torch.Tensor]: The received item, which may be None.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        """Shut the pipe down.

        Implementations are responsible for closing the communication
        channel and releasing every resource tied to it.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

View File

@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Union
import torch
import zmq
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
# Logger for this module.
logger = init_logger(__name__)
# Sentinel standing in for `None` on the wire: MooncakePipe sends a
# one-element tensor holding this integer instead of `None`, and maps a
# received one-element tensor equal to it back to `None`.
NONE_INT = -150886311
@dataclass
class MooncakeTransferEngineConfig:
    """Settings for the Mooncake transfer engine, loaded from a JSON file."""

    # "host:port" endpoints; MooncakeTransferEngine splits them on ':'.
    prefill_url: str
    decode_url: str
    # None selects the engine's default initialisation path.
    metadata_backend: Union[str, None]
    metadata_server: str
    # Defaults to "tcp" when absent from the file.
    protocol: str
    # Defaults to "" when absent from the file.
    device_name: str

    @staticmethod
    def from_file(file_path: str) -> 'MooncakeTransferEngineConfig':
        """Load the config from a JSON file.

        Args:
            file_path: Path of a JSON file whose top-level value is an
                object.

        Returns:
            The populated configuration.

        Raises:
            ValueError: If the top-level JSON value is not an object, or if
                `prefill_url`, `decode_url` or `metadata_server` is missing
                or null.
            OSError: If the file cannot be opened.
        """
        with open(file_path, encoding="utf-8") as fin:
            config = json.load(fin)
        if not isinstance(config, dict):
            raise ValueError(
                "Mooncake Configuration error. Expected a JSON object in "
                f"{file_path}, got {type(config).__name__}.")
        # Fail here rather than letting a None reach `.split(':')` or the
        # engine much later, far away from the actual cause.
        required = ("prefill_url", "decode_url", "metadata_server")
        missing = [key for key in required if config.get(key) is None]
        if missing:
            raise ValueError(
                "Mooncake Configuration error. Missing required field(s) "
                f"in {file_path}: {missing}.")
        return MooncakeTransferEngineConfig(
            prefill_url=config["prefill_url"],
            decode_url=config["decode_url"],
            metadata_backend=config.get("metadata_backend", None),
            metadata_server=config["metadata_server"],
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
        )

    @staticmethod
    def load_from_env() -> 'MooncakeTransferEngineConfig':
        """Load the config from the file named by `MOONCAKE_CONFIG_PATH`.

        Raises:
            ValueError: If the environment variable is not set, or if the
                file it points to is not a valid configuration.
        """
        config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
        if config_file_path is None:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeTransferEngineConfig.from_file(config_file_path)
class MooncakeTransferEngine:
    """Moves raw bytes between the prefill and decode processes.

    The payload is copied by Mooncake's `TransferEngine`; ZeroMQ sockets carry
    only the control messages: the (buffer address, length) announcement and
    the ACK after which the sender frees its buffer.
    """

    def __init__(self, kv_rank: int, local_rank: int):
        """Load the config, start the engine and open the ZeroMQ sockets.

        Args:
            kv_rank: 0 makes the prefill URL the local endpoint and the
                decode URL the remote one; any other value swaps the roles.
            local_rank: Added to the base ports so each local rank uses its
                own ports.

        Raises:
            ImportError: If the `mooncake` package cannot be imported.
            ValueError: If `MOONCAKE_CONFIG_PATH` is unset or
                `metadata_backend` is not supported.
        """
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e

        self.engine = TransferEngine()
        self.local_rank = local_rank

        try:
            self.config = MooncakeTransferEngineConfig.load_from_env()
            logger.info("Mooncake Configuration loaded successfully.")
        except ValueError as e:
            logger.error(e)
            raise
        except Exception as exc:
            logger.error(
                "An error occurred while loading the configuration: %s", exc)
            raise

        prefill_host, base_prefill_port = self.config.prefill_url.split(':')
        decode_host, base_decode_port = self.config.decode_url.split(':')

        # Avoid ports conflict when running prefill and decode on the same node
        if prefill_host == decode_host and \
                base_prefill_port == base_decode_port:
            base_decode_port = str(int(base_decode_port) + 100)

        prefill_port = int(base_prefill_port) + self.local_rank
        decode_port = int(base_decode_port) + self.local_rank
        self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
        self.decode_url = ':'.join([decode_host, str(decode_port)])

        self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
                        self.config.metadata_server, self.config.protocol,
                        self.config.device_name, self.config.metadata_backend)

        self.remote_url = (self.decode_url
                           if kv_rank == 0 else self.prefill_url)

        # Initialize ZeroMQ context and sockets
        self.context = zmq.Context()  # type: ignore[attr-defined]
        self.sender_socket = self.context.socket(zmq.constants.PUSH)
        self.receiver_socket = self.context.socket(zmq.constants.PULL)
        self.sender_ack = self.context.socket(zmq.constants.PULL)
        self.receiver_ack = self.context.socket(zmq.constants.PUSH)

        # Single worker that waits for ACKs and frees sent buffers.
        self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
        self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
                                     decode_host, base_decode_port)

    def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
                                d_host: str, d_port: str) -> None:
        """Set up ZeroMQ sockets for sending and receiving data.

        Args:
            kv_rank: 0 means the prefill side is local; otherwise the decode
                side is local.
            p_host: Prefill host.
            p_port: Prefill base port, as a decimal string.
            d_host: Decode host.
            d_port: Decode base port, as a decimal string.
        """
        # Offsets < 8 are left for initialization in case tp and pp are enabled
        p_rank_offset = int(p_port) + 8 + self.local_rank * 2
        d_rank_offset = int(d_port) + 8 + self.local_rank * 2

        if kv_rank == 0:
            local_host, local_offset = p_host, p_rank_offset
            remote_host, remote_offset = d_host, d_rank_offset
        else:
            local_host, local_offset = d_host, d_rank_offset
            remote_host, remote_offset = p_host, p_rank_offset

        # Each side binds its own data (+1) and ack (+2) ports and connects
        # to the peer's ports with the same offsets.
        self.sender_socket.bind(f"tcp://{local_host}:{local_offset + 1}")
        self.receiver_socket.connect(
            f"tcp://{remote_host}:{remote_offset + 1}")
        self.sender_ack.connect(f"tcp://{remote_host}:{remote_offset + 2}")
        self.receiver_ack.bind(f"tcp://{local_host}:{local_offset + 2}")

    def initialize(self, local_hostname: str, metadata_server: str,
                   protocol: str, device_name: str,
                   metadata_backend: Union[str, None]) -> None:
        """Initialize the mooncake instance.

        Raises:
            ValueError: If `metadata_backend` is given but is neither "etcd"
                nor "redis" (compared case-insensitively).
        """
        if metadata_backend is None:
            self.engine.initialize(local_hostname, metadata_server, protocol,
                                   device_name)
        else:
            supported_backend = ["etcd", "redis"]
            metadata_backend = metadata_backend.lower()
            if metadata_backend not in supported_backend:
                raise ValueError(
                    "Mooncake Configuration error. `metadata_backend`"
                    f" should be one of {supported_backend}.")

            self.engine.initialize_ext(local_hostname, metadata_server,
                                       protocol, device_name, metadata_backend)

    def allocate_managed_buffer(self, length: int) -> int:
        """Allocate a managed buffer of the specified length.

        Returns:
            The value returned by the engine for the new buffer.

        Raises:
            RuntimeError: If the engine returns a value <= 0.
        """
        ret = self.engine.allocate_managed_buffer(length)
        if ret <= 0:
            logger.error("Allocation Return Error")
            raise RuntimeError("Allocation Return Error")
        return ret

    def free_managed_buffer(self, buffer: int, length: int) -> int:
        """Free a previously allocated managed buffer."""
        return self.engine.free_managed_buffer(buffer, length)

    def transfer_sync(self, buffer: int, peer_buffer_address: int,
                      length: int) -> int:
        """Synchronously read `length` bytes from the remote peer.

        Calls the engine's `transfer_sync_read` with `self.remote_url`, the
        local `buffer` and the peer's `peer_buffer_address`.

        Raises:
            RuntimeError: If the engine returns a negative value.
        """
        ret = self.engine.transfer_sync_read(self.remote_url, buffer,
                                             peer_buffer_address, length)
        if ret < 0:
            logger.error("Transfer Return Error")
            raise RuntimeError("Transfer Return Error")
        return ret

    def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
                              length: int) -> int:
        """Write bytes to the allocated buffer."""
        return self.engine.write_bytes_to_buffer(buffer, user_data, length)

    def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
        """Read bytes from the allocated buffer."""
        return self.engine.read_bytes_from_buffer(buffer, length)

    def wait_for_ack(self, src_ptr: int, length: int) -> None:
        """Block until the receiver replies, then free the sent buffer.

        Submitted to `buffer_cleaner` by `send_bytes` so the sender does not
        wait for the reply itself.
        """
        try:
            ack = self.sender_ack.recv()
            if ack != b'ACK':
                logger.error("Failed to receive ACK from the receiver")
        finally:
            # Free the buffer even when recv() raises, otherwise it leaks.
            self.free_managed_buffer(src_ptr, length)

    def send_bytes(self, user_data: bytes) -> None:
        """Send bytes to the remote process.

        Copies `user_data` into a managed buffer, announces its address and
        length to the peer, and schedules the buffer to be freed once the
        peer acknowledges.
        """
        length = len(user_data)
        src_ptr = self.allocate_managed_buffer(length)
        try:
            self.write_bytes_to_buffer(src_ptr, user_data, length)
            self.sender_socket.send_multipart(
                [struct.pack("!Q", src_ptr),
                 struct.pack("!Q", length)])
            self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
        except Exception:
            # Nothing will ever wait for an ACK for this buffer, so free it
            # here instead of leaking it.
            self.free_managed_buffer(src_ptr, length)
            raise

    def recv_bytes(self) -> bytes:
        """Receive bytes from the remote process.

        Returns:
            The payload announced by the peer's `send_bytes`.
        """
        data = self.receiver_socket.recv_multipart()
        src_ptr = struct.unpack("!Q", data[0])[0]
        length = struct.unpack("!Q", data[1])[0]
        dst_ptr = self.allocate_managed_buffer(length)
        try:
            self.transfer_sync(dst_ptr, src_ptr, length)
            ret = self.read_bytes_from_buffer(dst_ptr, length)
            # The payload has been copied out; the sender may free its buffer.
            self.receiver_ack.send(b'ACK')
        finally:
            # Free the local buffer on failure as well as on success.
            self.free_managed_buffer(dst_ptr, length)
        return ret
class MooncakePipe(KVPipeBase):
    """MooncakeTransferEngine based Pipe implementation."""

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 device: Optional[str] = None):
        """Initialize the mooncake pipe and set related parameters.

        Args:
            local_rank: Local rank; used as the CUDA device index and passed
                on to the transfer engine.
            config: KV transfer configuration; `kv_rank` and
                `kv_buffer_device` are read from it.
            device: Device name overriding `config.kv_buffer_device`.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        self.transfer_engine = MooncakeTransferEngine(self.kv_rank,
                                                      self.local_rank)
        # Created lazily on the first send/recv.
        self.transport_thread: Optional[ThreadPoolExecutor] = None
        # One-element tensor sent in place of `None`.
        self.none_tensor = torch.tensor([NONE_INT], device=self.device)

    def _select_device(self, device: str) -> torch.device:
        """Return `cuda:<local_rank>` for "cuda", the CPU device otherwise."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def tensor_hash(self, tensor: torch.Tensor) -> int:
        """Return the hash of the tensor's data pointer (not its contents)."""
        return hash(tensor.data_ptr())

    def _send_impl(self, tensor: torch.Tensor) -> None:
        """Implement the tensor sending logic using safetensors."""
        self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))

    def _recv_impl(self) -> torch.Tensor:
        """Implement the tensor receiving logic using safetensors."""
        data = self.transfer_engine.recv_bytes()
        return safetensors_load(data)["tensor"].to(self.device)

    def _log_send_failure(self, future) -> None:
        """Log the exception, if any, raised by a background send."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error("Exception when trying to send tensor: %s", exc)

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """Queue a tensor (or `None`) for sending on the transport thread.

        Returns without waiting for the send to complete. A failure in the
        background send is logged rather than raised to the caller.

        Args:
            tensor: Tensor with at least one dimension, or `None`.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = tensor if tensor is not None else self.none_tensor
        assert (len(tensor.shape) > 0)
        future = self.transport_thread.submit(self._send_impl, tensor)
        # Without this callback an exception raised by `_send_impl` would be
        # stored on the discarded future and never surface anywhere.
        future.add_done_callback(self._log_send_failure)

    def recv_tensor(self) -> Optional[torch.Tensor]:
        """Receive the next tensor, blocking until it arrives.

        Returns:
            The received tensor, or `None` when the received tensor has
            exactly one element equal to `NONE_INT`.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = self.transport_thread.submit(self._recv_impl).result()
        if tensor.numel() == 1 and tensor.item() == NONE_INT:
            return None
        else:
            return tensor

    def close(self) -> None:
        """Cleanup logic when closing the pipe."""
        # Stop accepting new work. wait=False so that a worker still blocked
        # on a socket cannot hang close().
        if self.transport_thread is not None:
            self.transport_thread.shutdown(wait=False)
        self.transfer_engine.buffer_cleaner.shutdown(wait=False)

        self.transfer_engine.sender_socket.close()
        self.transfer_engine.receiver_socket.close()
        self.transfer_engine.sender_ack.close()
        self.transfer_engine.receiver_ack.close()
        self.transfer_engine.context.term()  # Terminate the ZMQ context
        logger.info("Closed the transfer engine and cleaned up resources.")

View File

@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.
Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional
import torch
from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
# Logger for this module.
logger = init_logger(__name__)
class BrokenPipeException(Exception):
    """Pipe failure carrying its text in the `message` attribute."""

    def __init__(self, message):
        # Keep the text on the instance and hand it to Exception as well, so
        # that str(exc) and exc.args reflect it.
        self.message = message
        Exception.__init__(self, message)
# Metadata exchanged ahead of each tensor; `_make_metadata` builds it with the
# keys "dtype" and "shape".
# NOTE(review): the stored values are `tensor.dtype` / `tensor.shape` (or
# None), not tensors, so the value type below looks inaccurate -- confirm.
Metadata = dict[str, Optional[torch.Tensor]]
class PyNcclPipe(KVPipeBase):
    """Pipe that sends tensors, or None, to the next rank of a process group
    and receives them from the previous one.

    Metadata (dtype and shape) is sent first as a Python object through the
    `StatelessProcessGroup`; the tensor itself follows through the
    device-specific send/recv functions.
    """

    # NOTE(review): these three constants are not referenced anywhere in this
    # class -- possibly kept for external users; confirm before removing.
    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 device: Optional[str] = None,
                 port_offset: int = 0):
        """Create the process group and choose the send/recv implementation.

        Args:
            local_rank: Local rank; used as the CUDA device index.
            config: KV transfer configuration (rank, world size, ip, port,
                buffer device, buffer size, extra config).
            device: Device name overriding `config.kv_buffer_device`.
            port_offset: Added to `config.kv_port` for this pipe.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl
        # set target rank
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables
        self.transport_thread: Optional[ThreadPoolExecutor] = None
        # Bytes queued by `send_tensor` whose send has not finished yet.
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> tuple[Callable[[torch.Tensor, int], None],
               Callable[[torch.Tensor, int], None]]:
        """Return the `(send, recv)` callables matching `self.device`."""
        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            # use PyNCCL for send / recv
            comm = PyNcclCommunicator(group, device=self.local_rank)
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                # Copy the received object into the caller's buffer in place.
                x[...] = group.recv_obj(src)

            recv = my_recv

        return send, recv

    def _select_device(self, device: str):
        """Return `cuda:<local_rank>` for "cuda", the CPU device otherwise."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Args:
            tensor: The input tensor or None if no tensor is provided.

        Returns:
            metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Args:
            metadata: A dictionary with keys "dtype" and "shape",
                describing the tensor's data type and shape.

        Returns:
            buffer: A tensor of the specified type and shape,
                allocated on `self.device`.
        """
        return torch.empty(metadata["shape"],
                           dtype=metadata["dtype"],
                           device=self.device)

    def _send_metadata(self, metadata: Metadata):
        """
        Send the metadata dictionary to the target rank.

        Args:
            metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            metadata: A dictionary with keys "dtype" and "shape"
                describing the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
        """
        The actual implementation of sending the tensor and its metadata to the
        target rank.

        Args:
            tensor: The input tensor to be sent, or `None` if no tensor is
                being sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        # For `None` the metadata alone tells the receiver nothing follows.
        if tensor is not None:
            self.device_send_func(tensor.to(self.device),
                                  self.target_rank_for_send)

    def _recv_impl(self) -> Optional[torch.Tensor]:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
            buffer: The received tensor, or `None` if no tensor is received.
        """
        metadata = self._recv_metadata()
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)

        return buffer

    def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
                            tensor_size: int) -> None:
        """
        Wrapper for _send_impl to handle exceptions and update buffer size.

        Exceptions are logged, not re-raised. `tensor_size` is subtracted
        from `buffer_size` whether or not the send succeeded.
        """
        try:
            self._send_impl(tensor)
        except Exception as e:
            logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
                         torch.distributed.get_rank(), str(tensor), str(e))
            import traceback
            traceback.print_exc()
        finally:
            # The tensor is no longer pending either way. Skipping this on
            # failure would leave the counter inflated for good and could
            # make `block_if_full` wait forever.
            with self.buffer_size_lock:
                self.buffer_size -= tensor_size

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        Args:
            tensor: The tensor to send, or `None` if no tensor is being sent.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        if tensor is not None:
            tensor_size = tensor.element_size() * tensor.numel()
        else:
            tensor_size = 0

        # Backpressure: wait while more bytes than the threshold are pending.
        self.block_if_full()

        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor,
                                     tensor_size)

    def recv_tensor(self) -> Optional[torch.Tensor]:
        """
        Receives a tensor and its metadata from the source rank. Blocking call.

        Returns:
            tensor: The received tensor, or `None` if no tensor is received.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            tensor = future.result()
        except Exception as e:
            logger.error("Encountering exception in KV receiving thread")
            logger.error("%s", e)
            logger.error("My device: %s", self.device)
            import traceback
            traceback.print_exc()
            raise

        return tensor

    def close(self):
        """
        Close the pipe and release associated resources.
        """
        # `hasattr` guards against close() being called on an instance whose
        # __init__ did not get as far as creating the attribute.
        if hasattr(self,
                   "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()