[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
77
vllm/distributed/kv_transfer/kv_connector_agent.py
Normal file
77
vllm/distributed/kv_transfer/kv_connector_agent.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A centralized entrypoint to perform distributed KV cache transfer.
|
||||
|
||||
This implementation is a shim wrapper on two APIs exposed by `kv_connector`:
|
||||
1. `send_kv_caches_and_hidden_states`
|
||||
2. `recv_kv_caches_and_hidden_states
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class KVTransferAgent:
|
||||
"""
|
||||
A class designated for distributed KV transfer
|
||||
|
||||
Target use cases:
|
||||
1. Disaggregated prefill
|
||||
2. Remote KV cache storage
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: "VllmConfig",
|
||||
):
|
||||
|
||||
self.config = config
|
||||
|
||||
if config.kv_transfer_config is None:
|
||||
raise ValueError("KVTransferConfig is not set in the VllmConfig,"
|
||||
" cannot initialize KVConnector.")
|
||||
|
||||
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
|
||||
"TransferAgent should only be used when kv_connector is set."
|
||||
|
||||
self.connector = KVConnectorFactory.create_connector_v0(
|
||||
rank, local_rank, config)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
|
||||
self.connector.send_kv_caches_and_hidden_states(
|
||||
model_executable, model_input, kv_caches,
|
||||
hidden_or_intermediate_states)
|
||||
|
||||
def close(self) -> None:
|
||||
self.connector.close()
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: list[torch.Tensor]
|
||||
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
|
||||
return self.connector.recv_kv_caches_and_hidden_states(
|
||||
model_executable, model_input, kv_caches)
|
||||
Reference in New Issue
Block a user