[Refactor]Refactor of vllm_ascend/distributed module (#5719)
### What this PR does / why we need it?
Based on the RFC:https://github.com/vllm-project/vllm-ascend/issues/5604
This PR is a refactoring of vllm_ascend/distributed, moving all
kv_transfer realtaed codes into a dedicated folder, which has already
been done in vLLM
### Does this PR introduce _any_ user-facing change?
NA
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: lty <linhebiwen@gmail.com>
This commit is contained in:
@@ -0,0 +1,203 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.hashing import sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
PrefixCachingMetrics)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import \
|
||||
get_manager_for_kv_cache_spec
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class CPUCacheStats:
|
||||
|
||||
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
self.log_stats = log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
|
||||
self.time_sec = int(time.time())
|
||||
|
||||
def log(self):
|
||||
current_time_sec = int(time.time())
|
||||
# Log the prefix cache hit rate every 10 seconds.
|
||||
if current_time_sec - self.time_sec >= 10:
|
||||
self.time_sec = current_time_sec
|
||||
logger.info("CPU Prefix cache hit rate: %.1f%%",
|
||||
self.cpu_prefix_cache_metrics.hit_rate * 100)
|
||||
|
||||
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
|
||||
"""Get (and reset) the prefix cache stats.
|
||||
Returns:
|
||||
The current prefix caching stats, or None if logging is disabled.
|
||||
"""
|
||||
if not self.log_stats:
|
||||
return None
|
||||
stats = self.prefix_cache_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats()
|
||||
return stats
|
||||
|
||||
def update(self, num_tokens, num_computed_tokens):
|
||||
# Note the function is called by scheduler
|
||||
if self.log_stats and self.enable_prefix_caching:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
self.prefix_cache_stats.queries += num_tokens
|
||||
self.prefix_cache_stats.hits += num_computed_tokens
|
||||
|
||||
def set_cache_stats(self, num_tokens, num_computed_tokens):
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.hits = num_computed_tokens
|
||||
self.prefix_cache_stats.queries = num_tokens
|
||||
self.prefix_cache_stats.requests = 1
|
||||
|
||||
|
||||
class CPUKVCacheManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
num_cpu_blocks: int,
|
||||
caching_hash_algo: str = "builtin",
|
||||
use_eagle: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
self.use_eagle = use_eagle
|
||||
self.block_pool = BlockPool(self.num_cpu_blocks, True,
|
||||
enable_kv_cache_events)
|
||||
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
# Record kv block hashes, avoid redundant computation.
|
||||
self.req_to_block_hashes: defaultdict[
|
||||
str, list[BlockHash]] = defaultdict(list)
|
||||
# Record blocks touched in get_matched_num_and_touch().
|
||||
self.req_to_computed_blocks: defaultdict[
|
||||
str, list[KVCacheBlock]] = defaultdict(list)
|
||||
# Record the request that failed to allocate.
|
||||
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
|
||||
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
|
||||
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
|
||||
log_stats=True)
|
||||
# Record request that will be free after finish sending
|
||||
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
|
||||
|
||||
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if (request.sampling_params.prompt_logprobs is not None):
|
||||
return 0, False
|
||||
request_id = request.request_id
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
block_hashes = self.req_to_block_hashes[request_id]
|
||||
if not block_hashes:
|
||||
block_hashes = request.block_hashes
|
||||
self.req_to_block_hashes[request_id] = block_hashes
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.single_type_manager.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
)
|
||||
num_computed_tokens = len(computed_blocks[0]) * self.block_size
|
||||
self.req_to_computed_blocks[request_id] = computed_blocks[0]
|
||||
# We should touch these blocks in the concurrent scenarios.
|
||||
self.block_pool.touch(computed_blocks)
|
||||
|
||||
# cup prefix cache status set and log
|
||||
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
|
||||
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
|
||||
num_computed_tokens)
|
||||
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
|
||||
self.cpu_cache_stats.prefix_cache_stats)
|
||||
self.cpu_cache_stats.log()
|
||||
|
||||
return num_computed_tokens, False
|
||||
|
||||
def _release_ahead_touch(self, request_id: str):
|
||||
computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
if computed_blocks:
|
||||
self.single_type_manager.block_pool.free_blocks(
|
||||
reversed(computed_blocks))
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
|
||||
def allocate_slots(self, req_to_num_tokens: dict[str, int],
|
||||
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
|
||||
for request_id in unallocated_req_ids:
|
||||
self._free_slots(request_id)
|
||||
req_to_new_blocks = {}
|
||||
for request_id, num_tokens in req_to_num_tokens.items():
|
||||
if self.req_failed_to_allocate[request_id]:
|
||||
continue
|
||||
new_computed_blocks = self.req_to_computed_blocks[request_id]
|
||||
num_blocks_to_allocate = (
|
||||
self.single_type_manager.get_num_blocks_to_allocate(
|
||||
request_id=request_id,
|
||||
num_tokens=num_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
))
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
self._release_ahead_touch(request_id)
|
||||
self.req_failed_to_allocate[request_id] = True
|
||||
continue
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.single_type_manager.save_new_computed_blocks(
|
||||
request_id, new_computed_blocks)
|
||||
# Allocate new blocks but do not cache now.
|
||||
new_blocks = self.single_type_manager.allocate_new_blocks(
|
||||
request_id, num_tokens)
|
||||
self.req_to_num_tokens[request_id] = num_tokens
|
||||
# No need to release ref_cnt because we use officially.
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
req_to_new_blocks[request_id] = [
|
||||
block.block_id for block in new_computed_blocks + new_blocks
|
||||
]
|
||||
return req_to_new_blocks
|
||||
|
||||
def record_request_cache_and_free_slots(self, request: Request):
|
||||
logger.debug(
|
||||
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
self.req_to_free[request.request_id] = request
|
||||
|
||||
def cache_and_free_slots(self, request_id: str):
|
||||
logger.debug(
|
||||
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
|
||||
)
|
||||
if request_id not in self.req_to_free:
|
||||
logger.Error(
|
||||
f"request {request_id} not in req_to_free, maybe bug!")
|
||||
return
|
||||
request = self.req_to_free[request_id]
|
||||
if not self.req_failed_to_allocate[request_id]:
|
||||
self.single_type_manager.cache_blocks(
|
||||
request,
|
||||
self.req_to_num_tokens[request_id],
|
||||
)
|
||||
self._free_slots(request_id)
|
||||
logger.debug(
|
||||
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
|
||||
del self.req_to_free[request_id]
|
||||
|
||||
def _free_slots(self, request_id: str):
|
||||
# This function is designed to be reentrant.
|
||||
self._release_ahead_touch(request_id)
|
||||
self.single_type_manager.free(request_id)
|
||||
self.req_to_block_hashes.pop(request_id, None)
|
||||
self.req_to_computed_blocks.pop(request_id, None)
|
||||
self.req_failed_to_allocate.pop(request_id, None)
|
||||
self.req_to_num_tokens.pop(request_id, None)
|
||||
@@ -0,0 +1,528 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention, MLAAttention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
MambaSpec, MLAAttentionSpec)
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.metadata import (
|
||||
MetadataServer, MetadataServerProc, MLAConfig)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
gpu_block_ids: list[int]
|
||||
cpu_block_ids: list[int]
|
||||
num_scheduled_tokens: int
|
||||
num_computed_tokens: int
|
||||
num_gpu_computed_tokens: int
|
||||
num_cpu_computed_tokens: int
|
||||
|
||||
def update(self, other: "ReqMeta"):
|
||||
self.gpu_block_ids.extend(other.gpu_block_ids)
|
||||
self.cpu_block_ids.extend(other.cpu_block_ids)
|
||||
self.num_scheduled_tokens = other.num_scheduled_tokens
|
||||
self.num_computed_tokens = other.num_computed_tokens
|
||||
self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
|
||||
self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
requests: dict[str, ReqMeta]
|
||||
finished_req_ids: set[str]
|
||||
|
||||
|
||||
class CPUOffloadingConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||
if not vllm_config.cache_config.enable_prefix_caching:
|
||||
self.connector_scheduler: Optional[
|
||||
CPUOffloadingConnectorScheduler] = None
|
||||
self.connector_worker: Optional[
|
||||
CPUOffloadingConnectorWorker] = None
|
||||
elif role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = CPUOffloadingConnectorScheduler(
|
||||
vllm_config)
|
||||
self.connector_worker = None
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_scheduler = None
|
||||
self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||
if self.connector_worker is not None:
|
||||
assert isinstance(connector_metadata,
|
||||
CPUOffloadingConnectorMetadata)
|
||||
self.connector_worker.bind_connector_metadata(connector_metadata)
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.clear_connector_metadata()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.start_load_kv()
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
if self.connector_worker is not None:
|
||||
self.connector_worker.wait_for_layer_load()
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(), None
|
||||
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
return 0, False
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.update_state_after_alloc(request)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
if self.connector_scheduler is not None:
|
||||
return self.connector_scheduler.build_connector_meta(
|
||||
scheduler_output)
|
||||
return KVConnectorMetadata()
|
||||
|
||||
def request_finished(
|
||||
self, request: "Request",
|
||||
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
if self.connector_scheduler is not None:
|
||||
self.connector_scheduler.request_finished(request)
|
||||
return True, None
|
||||
|
||||
|
||||
class CPUOffloadingConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorScheduler")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
self.num_gpu_computed_tokens: dict[str, int] = {}
|
||||
self.num_cpu_computed_tokens: dict[str, int] = {}
|
||||
self.allocated_req_ids: set[str] = set()
|
||||
self.finished_req_ids: list[str] = []
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.zmq_rpc_client.call("post_init")
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"swap_in_threshold", 0)
|
||||
else:
|
||||
self.swap_in_threshold = 0
|
||||
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, ori_request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
|
||||
"get_matched_num_and_touch", request)
|
||||
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
|
||||
self.num_cpu_computed_tokens[
|
||||
request.request_id] = num_cpu_computed_tokens
|
||||
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
|
||||
return num_cpu_computed_tokens - num_computed_tokens, load_async
|
||||
else:
|
||||
return 0, load_async
|
||||
|
||||
def update_state_after_alloc(self, request: "Request"):
|
||||
self.allocated_req_ids.add(request.request_id)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
num_tokens = {}
|
||||
# process scheduled_new_reqs
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
num_tokens[req_id] = (
|
||||
req.num_computed_tokens +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
# process scheduled_cached_reqs
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
num_tokens[req_id] = (
|
||||
cached_reqs.num_computed_tokens[idx] +
|
||||
scheduler_output.num_scheduled_tokens[req_id])
|
||||
|
||||
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
|
||||
self.allocated_req_ids -
|
||||
scheduler_output.num_scheduled_tokens.keys())
|
||||
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
|
||||
num_tokens,
|
||||
unallocated_req_ids)
|
||||
metadata = CPUOffloadingConnectorMetadata(
|
||||
requests={},
|
||||
finished_req_ids=set(self.finished_req_ids),
|
||||
)
|
||||
for req in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req.req_id
|
||||
gpu_block_ids = req.block_ids[0]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=req.num_computed_tokens,
|
||||
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
|
||||
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
|
||||
|
||||
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||
gpu_block_ids = cached_reqs.new_block_ids[idx]
|
||||
metadata.requests[req_id] = ReqMeta(
|
||||
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||
num_scheduled_tokens=scheduler_output.
|
||||
num_scheduled_tokens[req_id],
|
||||
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
|
||||
self.num_gpu_computed_tokens.clear()
|
||||
self.num_cpu_computed_tokens.clear()
|
||||
self.allocated_req_ids.clear()
|
||||
self.finished_req_ids.clear()
|
||||
return metadata
|
||||
|
||||
def request_finished(self, ori_request: "Request"):
|
||||
request = copy.deepcopy(ori_request)
|
||||
request.get_hash_new_full_blocks = None
|
||||
self.finished_req_ids.append(request.request_id)
|
||||
# inform metadata server to record request, and free it after finish sending
|
||||
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
|
||||
request)
|
||||
|
||||
|
||||
class CPUOffloadingConnectorWorker:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
logger.info("init CPUOffloadingConnectorWorker")
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.pp_rank = get_pp_group().rank_in_group
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_rank = self.tp_group.rank_in_group
|
||||
self.tp_world_size = self.tp_group.world_size
|
||||
self.use_mla = vllm_config.model_config.use_mla
|
||||
|
||||
self.requests: dict[str, ReqMeta] = {}
|
||||
self.load_stream = torch.npu.Stream()
|
||||
self.save_stream = torch.npu.Stream()
|
||||
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||
self.load_block_mapping: list[tuple[int, int]] = []
|
||||
self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
|
||||
self.save_output_queue: queue.Queue[str] = queue.Queue()
|
||||
self.save_thread = threading.Thread(target=self._save_listener)
|
||||
self.save_thread.start()
|
||||
self.done_sending_count: defaultdict[str, int] = defaultdict(int)
|
||||
|
||||
# start metadata server to init cpu_kv_cache_manager and handle rpc requests
|
||||
# all dp shared the same metadata server, only start the process on data_rank 0
|
||||
if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
|
||||
config = VllmConfig()
|
||||
config.cache_config = vllm_config.cache_config
|
||||
config.parallel_config = vllm_config.parallel_config
|
||||
config.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self.init_metadata_server(config)
|
||||
self._wait_for_metadata_process_start()
|
||||
|
||||
def init_metadata_server(self, vllm_config: VllmConfig):
|
||||
self.metadata_thread = threading.Thread(
|
||||
target=MetadataServerProc.run_metadata_server,
|
||||
args=(vllm_config, ),
|
||||
)
|
||||
self.metadata_thread.daemon = True
|
||||
self.metadata_thread.start()
|
||||
|
||||
def _wait_for_metadata_process_start(self):
|
||||
# TODO: wait for metadata server to start, add a rpc to check if ready
|
||||
while True:
|
||||
try:
|
||||
if self.zmq_rpc_client.call("ready"):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.info(f"wait for metadata server to start, error: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
def bind_connector_metadata(
|
||||
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
|
||||
for req_id, req in connector_metadata.requests.items():
|
||||
if req_id in self.requests:
|
||||
self.requests[req_id].update(req)
|
||||
req = self.requests[req_id]
|
||||
else:
|
||||
self.requests[req_id] = req
|
||||
for i in range(req.num_gpu_computed_tokens // self.block_size,
|
||||
req.num_computed_tokens // self.block_size):
|
||||
self.load_block_mapping.append(
|
||||
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
|
||||
for req_id in connector_metadata.finished_req_ids:
|
||||
if req_id in self.requests:
|
||||
self.save_input_queue.put((req_id, self.requests[req_id]))
|
||||
|
||||
def clear_connector_metadata(self) -> None:
|
||||
self.load_block_mapping.clear()
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
|
||||
self.gpu_kv_caches = kv_caches
|
||||
model_config = self.vllm_config.model_config
|
||||
mla_config: Optional[MLAConfig] = None
|
||||
if model_config.use_mla:
|
||||
mla_config = MLAConfig(
|
||||
model_config.hf_text_config.kv_lora_rank,
|
||||
model_config.hf_text_config.qk_rope_head_dim)
|
||||
self.cpu_kv_caches = list(
|
||||
self.zmq_rpc_client.call(
|
||||
"init_cpu_kv_caches",
|
||||
self.pp_rank,
|
||||
self.tp_rank,
|
||||
get_kv_cache_spec(self.vllm_config),
|
||||
mla_config,
|
||||
).values())
|
||||
|
||||
def start_load_kv(self) -> None:
|
||||
self.current_layer = 0
|
||||
self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
|
||||
self.load_kv_layer(0)
|
||||
|
||||
def wait_for_layer_load(self) -> None:
|
||||
# TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
|
||||
self.load_stream.synchronize()
|
||||
self.current_layer += 1
|
||||
self.load_kv_layer(self.current_layer)
|
||||
|
||||
def load_kv_layer(self, layer: int):
|
||||
if layer == len(self.gpu_kv_caches):
|
||||
return
|
||||
gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
|
||||
cpu_kv_caches = self.cpu_kv_caches[layer]
|
||||
with torch.npu.stream(self.load_stream):
|
||||
for cpu_block_id, gpu_block_id in self.load_block_mapping:
|
||||
for gpu_layer_part, cpu_layer_part in zip(
|
||||
gpu_kv_caches, cpu_kv_caches):
|
||||
gpu_layer_part[gpu_block_id].copy_(
|
||||
cpu_layer_part[cpu_block_id], non_blocking=True)
|
||||
|
||||
def get_finished(self) -> set[str]:
|
||||
done_sending: set[str] = set()
|
||||
while True:
|
||||
try:
|
||||
id = self.save_output_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
done_sending.add(id)
|
||||
for id in done_sending:
|
||||
del self.requests[id]
|
||||
if self.tp_world_size == 1:
|
||||
return done_sending
|
||||
if self.tp_rank == 0:
|
||||
for req_id in done_sending:
|
||||
self.done_sending_count[req_id] += 1
|
||||
other_ranks_finished_ids: list[str] = []
|
||||
for i in range(1, self.tp_world_size):
|
||||
other_ranks_finished_ids.extend(
|
||||
self.tp_group.recv_object(src=i))
|
||||
for req_id in other_ranks_finished_ids:
|
||||
self.done_sending_count[req_id] += 1
|
||||
all_done_sending: set[str] = set()
|
||||
for req_id in list(self.done_sending_count.keys()):
|
||||
if self.done_sending_count[req_id] == self.tp_world_size:
|
||||
del self.done_sending_count[req_id]
|
||||
all_done_sending.add(req_id)
|
||||
# release cpu_kv_cache after request sending finished
|
||||
# to avoid rpc blocking, use thread to call rpc asynchronously
|
||||
sending_finished_thread = threading.Thread(
|
||||
target=self._sending_finished, args=(all_done_sending, ))
|
||||
sending_finished_thread.daemon = True
|
||||
sending_finished_thread.start()
|
||||
|
||||
return all_done_sending
|
||||
else:
|
||||
self.tp_group.send_object(done_sending, dst=0)
|
||||
return done_sending
|
||||
|
||||
def _sending_finished(self, all_done_sending):
|
||||
for req_id in all_done_sending:
|
||||
logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
|
||||
self.zmq_rpc_client.call("cache_and_free_slots", req_id)
|
||||
|
||||
def _save_listener(self):
|
||||
save_block_mapping = []
|
||||
while True:
|
||||
req_id, req = self.save_input_queue.get()
|
||||
for i in range(
|
||||
req.num_cpu_computed_tokens // self.block_size,
|
||||
min((req.num_computed_tokens + req.num_scheduled_tokens) //
|
||||
self.block_size, len(req.cpu_block_ids))):
|
||||
save_block_mapping.append(
|
||||
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
|
||||
with torch.npu.stream(self.save_stream):
|
||||
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
|
||||
# non-MLA: kv_layer is list[tensor], typically means [k, v].
|
||||
if self.use_mla:
|
||||
start, step = self.tp_rank, self.tp_world_size
|
||||
else:
|
||||
start, step = 0, 1
|
||||
for i in range(start, len(save_block_mapping), step):
|
||||
gpu_block_id, cpu_block_id = save_block_mapping[i]
|
||||
for cpu_kv_caches, gpu_kv_caches in zip(
|
||||
self.cpu_kv_caches, self.gpu_kv_caches.values()):
|
||||
for cpu_layer_part, gpu_layer_part in zip(
|
||||
cpu_kv_caches, gpu_kv_caches):
|
||||
cpu_layer_part[cpu_block_id].copy_(
|
||||
gpu_layer_part[gpu_block_id],
|
||||
non_blocking=True)
|
||||
self.save_stream.synchronize()
|
||||
self.save_output_queue.put(req_id)
|
||||
save_block_mapping.clear()
|
||||
|
||||
|
||||
# copied and modified from vllm_ascend/worker/model_runner_v1.py
|
||||
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
Returns:
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
if has_ec_transfer() and get_ec_transfer().is_producer:
|
||||
return {}
|
||||
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
|
||||
if vllm_config.cache_config.cache_dtype == "auto":
|
||||
kv_cache_dtype = vllm_config.model_config.dtype
|
||||
else:
|
||||
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
vllm_config.cache_config.cache_dtype]
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if isinstance(attn_module, Attention):
|
||||
# TODO: Support other attention modules, e.g., cross-attention
|
||||
# TODO(lucas): move the attention specs into the model layers like
|
||||
# the attention backends
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
continue
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
|
||||
elif isinstance(attn_module, MLAAttention):
|
||||
if use_mla and not use_sparse:
|
||||
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype,
|
||||
cache_dtype_str=vllm_config.cache_config.cache_dtype)
|
||||
else:
|
||||
# TODO(cmq): This is a hack way to fix deepseek kvcache when
|
||||
# using DSA. Fix the spec in vLLM is a finnal way.
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=kv_cache_dtype)
|
||||
|
||||
mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
|
||||
if len(mamba_layers) > 0:
|
||||
if (vllm_config.speculative_config is not None
|
||||
and vllm_config.model_config.hf_config.model_type
|
||||
not in ["qwen3_next"]):
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet.")
|
||||
if vllm_config.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported for Mamba yet.")
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)
|
||||
|
||||
# Set block_size to max_model_len, so that mamba model will always
|
||||
# have only one block in the KV cache.
|
||||
for layer_name, mamba_module in mamba_layers.items():
|
||||
kv_cache_spec[layer_name] = MambaSpec(
|
||||
shapes=mamba_module.get_state_shape(),
|
||||
dtypes=mamba_module.get_state_dtype(),
|
||||
block_size=max_model_len,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=mamba_module.mamba_type,
|
||||
num_speculative_blocks=(
|
||||
vllm_config.speculative_config.num_speculative_tokens
|
||||
if vllm_config.speculative_config else 0),
|
||||
)
|
||||
|
||||
return kv_cache_spec
|
||||
Reference in New Issue
Block a user