[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
280
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Normal file
280
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from safetensors.torch import load as safetensors_load
|
||||
from safetensors.torch import save as safetensors_save
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
NONE_INT = -150886311
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeTransferEngineConfig:
|
||||
prefill_url: str
|
||||
decode_url: str
|
||||
metadata_backend: Union[str, None]
|
||||
metadata_server: str
|
||||
protocol: str
|
||||
device_name: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> 'MooncakeTransferEngineConfig':
|
||||
"""Load the config from a JSON file."""
|
||||
with open(file_path) as fin:
|
||||
config = json.load(fin)
|
||||
return MooncakeTransferEngineConfig(
|
||||
prefill_url=config.get("prefill_url"),
|
||||
decode_url=config.get("decode_url"),
|
||||
metadata_backend=config.get("metadata_backend", None),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> 'MooncakeTransferEngineConfig':
|
||||
"""Load config from a file specified in the environment variable."""
|
||||
config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
|
||||
if config_file_path is None:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeTransferEngineConfig.from_file(config_file_path)
|
||||
|
||||
|
||||
class MooncakeTransferEngine:
|
||||
"""Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""
|
||||
|
||||
def __init__(self, kv_rank: int, local_rank: int):
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
|
||||
self.engine = TransferEngine()
|
||||
self.local_rank = local_rank
|
||||
|
||||
try:
|
||||
self.config = MooncakeTransferEngineConfig.load_from_env()
|
||||
logger.info("Mooncake Configuration loaded successfully.")
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
|
||||
decode_host, base_decode_port = self.config.decode_url.split(':')
|
||||
|
||||
# Avoid ports conflict when running prefill and decode on the same node
|
||||
if prefill_host == decode_host and \
|
||||
base_prefill_port == base_decode_port:
|
||||
base_decode_port = str(int(base_decode_port) + 100)
|
||||
|
||||
prefill_port = int(base_prefill_port) + self.local_rank
|
||||
decode_port = int(base_decode_port) + self.local_rank
|
||||
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
|
||||
self.decode_url = ':'.join([decode_host, str(decode_port)])
|
||||
|
||||
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
|
||||
self.config.metadata_server, self.config.protocol,
|
||||
self.config.device_name, self.config.metadata_backend)
|
||||
|
||||
self.remote_url = (self.decode_url
|
||||
if kv_rank == 0 else self.prefill_url)
|
||||
|
||||
# Initialize ZeroMQ context and sockets
|
||||
self.context = zmq.Context() # type: ignore[attr-defined]
|
||||
self.sender_socket = self.context.socket(zmq.constants.PUSH)
|
||||
self.receiver_socket = self.context.socket(zmq.constants.PULL)
|
||||
self.sender_ack = self.context.socket(zmq.constants.PULL)
|
||||
self.receiver_ack = self.context.socket(zmq.constants.PUSH)
|
||||
|
||||
self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
|
||||
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
|
||||
decode_host, base_decode_port)
|
||||
|
||||
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
|
||||
d_host: str, d_port: str) -> None:
|
||||
"""Set up ZeroMQ sockets for sending and receiving data."""
|
||||
# Offsets < 8 are left for initialization in case tp and pp are enabled
|
||||
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
||||
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
||||
if kv_rank == 0:
|
||||
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
else:
|
||||
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
|
||||
def initialize(self, local_hostname: str, metadata_server: str,
|
||||
protocol: str, device_name: str,
|
||||
metadata_backend: Union[str, None]) -> None:
|
||||
"""Initialize the mooncake instance."""
|
||||
if metadata_backend is None:
|
||||
self.engine.initialize(local_hostname, metadata_server, protocol,
|
||||
device_name)
|
||||
else:
|
||||
supported_backend = ["etcd", "redis"]
|
||||
metadata_backend = metadata_backend.lower()
|
||||
if metadata_backend not in supported_backend:
|
||||
raise ValueError(
|
||||
"Mooncake Configuration error. `metadata_backend`"
|
||||
f" should be one of {supported_backend}.")
|
||||
|
||||
self.engine.initialize_ext(local_hostname, metadata_server,
|
||||
protocol, device_name, metadata_backend)
|
||||
|
||||
def allocate_managed_buffer(self, length: int) -> int:
|
||||
"""Allocate a managed buffer of the specified length."""
|
||||
ret = self.engine.allocate_managed_buffer(length)
|
||||
if ret <= 0:
|
||||
logger.error("Allocation Return Error")
|
||||
raise Exception("Allocation Return Error")
|
||||
return ret
|
||||
|
||||
def free_managed_buffer(self, buffer: int, length: int) -> int:
|
||||
"""Free a previously allocated managed buffer."""
|
||||
return self.engine.free_managed_buffer(buffer, length)
|
||||
|
||||
def transfer_sync(self, buffer: int, peer_buffer_address: int,
|
||||
length: int) -> int:
|
||||
"""Synchronously transfer data to the specified address."""
|
||||
ret = self.engine.transfer_sync_read(self.remote_url, buffer,
|
||||
peer_buffer_address, length)
|
||||
if ret < 0:
|
||||
logger.error("Transfer Return Error")
|
||||
raise Exception("Transfer Return Error")
|
||||
return ret
|
||||
|
||||
def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
|
||||
length: int) -> int:
|
||||
"""Write bytes to the allocated buffer."""
|
||||
return self.engine.write_bytes_to_buffer(buffer, user_data, length)
|
||||
|
||||
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
|
||||
"""Read bytes from the allocated buffer."""
|
||||
return self.engine.read_bytes_from_buffer(buffer, length)
|
||||
|
||||
def wait_for_ack(self, src_ptr: int, length: int) -> None:
|
||||
"""Asynchronously wait for ACK from the receiver."""
|
||||
ack = self.sender_ack.recv()
|
||||
if ack != b'ACK':
|
||||
logger.error("Failed to receive ACK from the receiver")
|
||||
|
||||
self.free_managed_buffer(src_ptr, length)
|
||||
|
||||
def send_bytes(self, user_data: bytes) -> None:
|
||||
"""Send bytes to the remote process."""
|
||||
length = len(user_data)
|
||||
src_ptr = self.allocate_managed_buffer(length)
|
||||
self.write_bytes_to_buffer(src_ptr, user_data, length)
|
||||
self.sender_socket.send_multipart(
|
||||
[struct.pack("!Q", src_ptr),
|
||||
struct.pack("!Q", length)])
|
||||
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
|
||||
|
||||
def recv_bytes(self) -> bytes:
|
||||
"""Receive bytes from the remote process."""
|
||||
data = self.receiver_socket.recv_multipart()
|
||||
src_ptr = struct.unpack("!Q", data[0])[0]
|
||||
length = struct.unpack("!Q", data[1])[0]
|
||||
dst_ptr = self.allocate_managed_buffer(length)
|
||||
self.transfer_sync(dst_ptr, src_ptr, length)
|
||||
ret = self.read_bytes_from_buffer(dst_ptr, length)
|
||||
|
||||
# Buffer cleanup
|
||||
self.receiver_ack.send(b'ACK')
|
||||
self.free_managed_buffer(dst_ptr, length)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class MooncakePipe(KVPipeBase):
|
||||
"""MooncakeTransferEngine based Pipe implementation."""
|
||||
|
||||
def __init__(self,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
device: Optional[str] = None):
|
||||
"""Initialize the mooncake pipe and set related parameters."""
|
||||
self.config = config
|
||||
self.local_rank = local_rank
|
||||
self.kv_rank = self.config.kv_rank
|
||||
if device is None:
|
||||
self.device = self._select_device(self.config.kv_buffer_device)
|
||||
else:
|
||||
self.device = self._select_device(device)
|
||||
|
||||
self.transfer_engine = MooncakeTransferEngine(self.kv_rank,
|
||||
self.local_rank)
|
||||
self.transport_thread: Optional[ThreadPoolExecutor] = None
|
||||
self.none_tensor = torch.tensor([NONE_INT], device=self.device)
|
||||
|
||||
def _select_device(self, device: str) -> torch.device:
|
||||
"""Select available device (CUDA or CPU)."""
|
||||
logger.info("Selecting device: %s", device)
|
||||
if device == "cuda":
|
||||
return torch.device(f"cuda:{self.local_rank}")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def tensor_hash(self, tensor: torch.Tensor) -> int:
|
||||
"""Calculate the hash value of the tensor."""
|
||||
return hash(tensor.data_ptr())
|
||||
|
||||
def _send_impl(self, tensor: torch.Tensor) -> None:
|
||||
"""Implement the tensor sending logic using safetensors."""
|
||||
self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))
|
||||
|
||||
def _recv_impl(self) -> torch.Tensor:
|
||||
"""Implement the tensor receiving logic using safetensors."""
|
||||
data = self.transfer_engine.recv_bytes()
|
||||
return safetensors_load(data)["tensor"].to(self.device)
|
||||
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""Send tensor to the target process."""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
tensor = tensor if tensor is not None else self.none_tensor
|
||||
assert (len(tensor.shape) > 0)
|
||||
self.transport_thread.submit(self._send_impl, tensor)
|
||||
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
"""Receive tensor from other processes."""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
tensor = self.transport_thread.submit(self._recv_impl).result()
|
||||
if tensor.numel() == 1 and tensor.item() == NONE_INT:
|
||||
return None
|
||||
else:
|
||||
return tensor
|
||||
|
||||
def close(self) -> None:
|
||||
"""Cleanup logic when closing the pipe."""
|
||||
self.transfer_engine.sender_socket.close()
|
||||
self.transfer_engine.receiver_socket.close()
|
||||
self.transfer_engine.sender_ack.close()
|
||||
self.transfer_engine.receiver_ack.close()
|
||||
self.transfer_engine.context.term() # Terminate the ZMQ context
|
||||
logger.info("Closed the transfer engine and cleaned up resources.")
|
||||
Reference in New Issue
Block a user