[Refactor]Refactor of vllm_ascend/distributed module (#5910)
### What this PR does / why we need it?
Based on the RFC:https://github.com/vllm-project/vllm-ascend/issues/5604
This PR is a refactoring of vllm_ascend/distributed.
### Does this PR introduce _any_ user-facing change?
NA
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
Signed-off-by: lty <linhebiwen@gmail.com>
This commit is contained in:
@@ -23,7 +23,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
MambaSpec, MLAAttentionSpec)
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.metadata import (
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
|
||||
MetadataServer, MetadataServerProc, MLAConfig)
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
253
vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py
Normal file
253
vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
from ucm.integration.vllm.ucm_connector import UCMConnector
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# isort: off
|
||||
if TYPE_CHECKING:
|
||||
if vllm_version_is('0.13.0'):
|
||||
from vllm.attention.backends.abstract import AttentionMetadata # type: ignore
|
||||
else:
|
||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
# isort: on
|
||||
|
||||
|
||||
class UCMConnectorV1(KVConnectorBase_V1):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
role=role,
|
||||
kv_cache_config=kv_cache_config)
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
|
||||
ImplCls = UCMConnector
|
||||
self._ucm_engine = ImplCls(vllm_config, role)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
Args:
|
||||
kv_caches: A dictionary mapping layer names to KV cache tensors.
|
||||
"""
|
||||
self._ucm_engine.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs: Any) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
self._ucm_engine.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
self._ucm_engine.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Start saving the a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata,
|
||||
**kwargs)
|
||||
|
||||
def wait_for_save(self) -> None:
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
self._ucm_engine.wait_for_save()
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
"""Clear the connector metadata.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
after the model execution.
|
||||
"""
|
||||
self._ucm_engine.clear_connector_metadata()
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
"""Set the connector metadata from the scheduler.
|
||||
|
||||
This function should be called by the model runner every time
|
||||
before the model execution. The metadata will be used for runtime
|
||||
KV cache loading and saving.
|
||||
|
||||
Args:
|
||||
connector_metadata (dict): the connector metadata.
|
||||
"""
|
||||
self._ucm_engine.bind_connector_metadata(connector_metadata)
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
|
||||
"""
|
||||
Get the set of block IDs that failed to load.
|
||||
|
||||
Returns:
|
||||
Set of block IDs that encountered load errors.
|
||||
Empty set if no load errors occurred.
|
||||
"""
|
||||
return self._ucm_engine.get_block_ids_with_load_errors()
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
return self._ucm_engine.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int) -> None:
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
"""
|
||||
self._ucm_engine.update_state_after_alloc(request, blocks,
|
||||
num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
return self._ucm_engine.build_connector_meta(scheduler_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return self._ucm_engine.request_finished(request, block_ids)
|
||||
|
||||
# ==============================
|
||||
# Metrics & Stats
|
||||
# ==============================
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls,
|
||||
data: dict[str, Any] | None = None
|
||||
) -> Optional["KVConnectorStats"]:
|
||||
"""
|
||||
KVConnectorStats resolution method. This method allows dynamically
|
||||
registered connectors to return their own KVConnectorStats object,
|
||||
which can implement custom aggregation logic on the data dict.
|
||||
"""
|
||||
return UCMConnector.build_kv_connector_stats(data)
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[object]],
|
||||
) -> Optional["KVConnectorPromMetrics"]:
|
||||
"""
|
||||
Create a KVConnectorPromMetrics subclass which should register
|
||||
per-connector Prometheus metrics and implement observe() to
|
||||
expose connector transfer stats via Prometheus.
|
||||
|
||||
This implementation forwards the call to the underlying
|
||||
UCMConnector engine.
|
||||
"""
|
||||
return UCMConnector.build_prom_metrics(
|
||||
vllm_config,
|
||||
metric_types,
|
||||
labelnames,
|
||||
per_engine_labelvalues,
|
||||
)
|
||||
Reference in New Issue
Block a user