[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #5) (#5996)

### What this PR does / why we need it?
**Scope of Changes**:

| File Path |
| :--- |
| `.../distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py` |
| `.../distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py` |
| `.../distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py` |
| `vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py` |
| `vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py` |
| `vllm_ascend/distributed/kv_transfer/utils/utils.py` |
| `vllm_ascend/kv_offload/cpu_npu.py` |
| `vllm_ascend/kv_offload/npu.py` |
| `vllm_ascend/lora/lora_ops.py` |
| `vllm_ascend/lora/punica_npu.py` |
| `vllm_ascend/lora/utils.py` |
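
The snippet below is a hypothetical minimal sketch (the class and method names are invented for illustration, not taken from this repository) of the two kinds of rewrite that recur throughout the diff: call and signature wrapping replaced by single lines that fit the configured line length, and `Optional[X]` annotations rewritten as PEP 604 `X | None` unions (presumably applied via ruff's pyupgrade rules rather than `ruff format` itself).

```python
from typing import Optional


class BeforeStyle:
    """Pre-ruff formatting: continuations aligned under the opening parenthesis."""

    def __init__(self) -> None:
        self._table: dict[str, int] = {}

    def lookup(self, key: str,
               default: int = 0) -> Optional[int]:
        return self._table.get(key, default)


class AfterStyle:
    """After reformatting: one line when it fits; Optional[int] becomes int | None."""

    def __init__(self) -> None:
        self._table: dict[str, int] = {}

    def lookup(self, key: str, default: int = 0) -> int | None:
        return self._table.get(key, default)
```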

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: 2c24bc6996

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Authored by SILONG ZENG on 2026-01-24 22:45:38 +08:00, committed by GitHub
parent 7faa6878a6, commit 6ccccad102
21 changed files with 866 additions and 1034 deletions

`vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py`

@@ -1,20 +1,17 @@
import time
from collections import defaultdict
from typing import Optional
from vllm.logger import logger
from vllm.utils.hashing import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock)
from vllm.v1.core.single_type_kv_cache_manager import \
get_manager_for_kv_cache_spec
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.metrics.stats import (PrefixCacheStats, CachingMetrics)
from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
from vllm.v1.request import Request
class CPUCacheStats:
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
self.enable_prefix_caching = enable_prefix_caching
self.log_stats = log_stats
@@ -27,10 +24,9 @@ class CPUCacheStats:
# Log the prefix cache hit rate every 10 seconds.
if current_time_sec - self.time_sec >= 10:
self.time_sec = current_time_sec
logger.info("CPU Prefix cache hit rate: %.1f%%",
self.cpu_prefix_cache_metrics.hit_rate * 100)
logger.info("CPU Prefix cache hit rate: %.1f%%", self.cpu_prefix_cache_metrics.hit_rate * 100)
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats, or None if logging is disabled.
@@ -57,7 +53,6 @@ class CPUCacheStats:
class CPUKVCacheManager:
def __init__(
self,
kv_cache_spec: KVCacheSpec,
@@ -70,30 +65,26 @@ class CPUKVCacheManager:
self.num_cpu_blocks = num_cpu_blocks
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.use_eagle = use_eagle
self.block_pool = BlockPool(self.num_cpu_blocks, True,
enable_kv_cache_events)
self.block_pool = BlockPool(self.num_cpu_blocks, True, enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=0,
)
# Record kv block hashes, avoid redundant computation.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
self.req_to_block_hashes: defaultdict[str, list[BlockHash]] = defaultdict(list)
# Record blocks touched in get_matched_num_and_touch().
self.req_to_computed_blocks: defaultdict[
str, list[KVCacheBlock]] = defaultdict(list)
self.req_to_computed_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)
# Record the request that failed to allocate.
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
log_stats=True)
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, log_stats=True)
# Record request that will be free after finish sending
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
# When the request requires prompt logprobs, we skip prefix caching.
if (request.sampling_params.prompt_logprobs is not None):
if request.sampling_params.prompt_logprobs is not None:
return 0, False
request_id = request.request_id
# The block hashes for the request may already be computed
@@ -119,10 +110,8 @@ class CPUKVCacheManager:
# cpu prefix cache status set and log
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.set_cache_stats(request.num_tokens, num_computed_tokens)
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(self.cpu_cache_stats.prefix_cache_stats)
self.cpu_cache_stats.log()
return num_computed_tokens, False
@@ -130,12 +119,10 @@ class CPUKVCacheManager:
def _release_ahead_touch(self, request_id: str):
computed_blocks = self.req_to_computed_blocks[request_id]
if computed_blocks:
self.single_type_manager.block_pool.free_blocks(
reversed(computed_blocks))
self.single_type_manager.block_pool.free_blocks(reversed(computed_blocks))
self.req_to_computed_blocks.pop(request_id, None)
def allocate_slots(self, req_to_num_tokens: dict[str, int],
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
def allocate_slots(self, req_to_num_tokens: dict[str, int], unallocated_req_ids: set[str]) -> dict[str, list[int]]:
for request_id in unallocated_req_ids:
self._free_slots(request_id)
req_to_new_blocks = {}
@@ -143,44 +130,34 @@ class CPUKVCacheManager:
if self.req_failed_to_allocate[request_id]:
continue
new_computed_blocks = self.req_to_computed_blocks[request_id]
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
))
num_blocks_to_allocate = self.single_type_manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=num_tokens,
new_computed_blocks=new_computed_blocks,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
self._release_ahead_touch(request_id)
self.req_failed_to_allocate[request_id] = True
continue
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.single_type_manager.save_new_computed_blocks(
request_id, new_computed_blocks)
self.single_type_manager.save_new_computed_blocks(request_id, new_computed_blocks)
# Allocate new blocks but do not cache now.
new_blocks = self.single_type_manager.allocate_new_blocks(
request_id, num_tokens)
new_blocks = self.single_type_manager.allocate_new_blocks(request_id, num_tokens)
self.req_to_num_tokens[request_id] = num_tokens
# No need to release ref_cnt because we use officially.
self.req_to_computed_blocks.pop(request_id, None)
req_to_new_blocks[request_id] = [
block.block_id for block in new_computed_blocks + new_blocks
]
req_to_new_blocks[request_id] = [block.block_id for block in new_computed_blocks + new_blocks]
return req_to_new_blocks
def record_request_cache_and_free_slots(self, request: Request):
logger.debug(
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
)
logger.debug(f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager")
self.req_to_free[request.request_id] = request
def cache_and_free_slots(self, request_id: str):
logger.debug(
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
)
logger.debug(f"Cache and free slots for request {request_id} in cpu_kv_cache_manager")
if request_id not in self.req_to_free:
logger.Error(
f"request {request_id} not in req_to_free, maybe bug!")
logger.Error(f"request {request_id} not in req_to_free, maybe bug!")
return
request = self.req_to_free[request_id]
if not self.req_failed_to_allocate[request_id]:
@@ -189,8 +166,7 @@ class CPUKVCacheManager:
self.req_to_num_tokens[request_id],
)
self._free_slots(request_id)
logger.debug(
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
logger.debug(f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
del self.req_to_free[request_id]
def _free_slots(self, request_id: str):

`vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py`

@@ -5,15 +5,15 @@ import queue
import threading
import time
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Sequence
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -23,12 +23,14 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)
MetadataServer,
MetadataServerProc,
MLAConfig,
)
if TYPE_CHECKING:
from vllm.v1.attention.backend import AttentionMetadata #type: ignore
from vllm.forward_context import ForwardContext
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@@ -59,20 +61,15 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
class CPUOffloadingConnector(KVConnectorBase_V1):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None):
def __init__(
self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None
):
self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set())
if not vllm_config.cache_config.enable_prefix_caching:
self.connector_scheduler: Optional[
CPUOffloadingConnectorScheduler] = None
self.connector_worker: Optional[
CPUOffloadingConnectorWorker] = None
self.connector_scheduler: CPUOffloadingConnectorScheduler | None = None
self.connector_worker: CPUOffloadingConnectorWorker | None = None
elif role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = CPUOffloadingConnectorScheduler(
vllm_config)
self.connector_scheduler = CPUOffloadingConnectorScheduler(vllm_config)
self.connector_worker = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
@@ -82,11 +79,9 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
# Worker-side methods
# ==============================
def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None:
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
if self.connector_worker is not None:
assert isinstance(connector_metadata,
CPUOffloadingConnectorMetadata)
assert isinstance(connector_metadata, CPUOffloadingConnectorMetadata)
self.connector_worker.bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
@@ -97,8 +92,7 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
if self.connector_worker is not None:
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
if self.connector_worker is not None:
self.connector_worker.start_load_kv()
@@ -106,53 +100,42 @@ class CPUOffloadingConnector(KVConnectorBase_V1):
if self.connector_worker is not None:
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
def save_kv_layer(
self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs
) -> None:
pass
def wait_for_save(self):
pass
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str] | None, set[str] | None]:
assert self.connector_worker is not None
return self.connector_worker.get_finished(), None
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
if self.connector_scheduler is not None:
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens)
return 0, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int):
if self.connector_scheduler is not None:
return self.connector_scheduler.update_state_after_alloc(request)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
if self.connector_scheduler is not None:
return self.connector_scheduler.build_connector_meta(
scheduler_output)
return self.connector_scheduler.build_connector_meta(scheduler_output)
return KVConnectorMetadata()
def request_finished(
self, request: "Request",
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
def request_finished(self, request: "Request", block_ids: list[int]) -> tuple[bool, dict[str, Any] | None]:
if self.connector_scheduler is not None:
self.connector_scheduler.request_finished(request)
return True, None
class CPUOffloadingConnectorScheduler:
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorScheduler")
self.vllm_config = vllm_config
@@ -165,22 +148,17 @@ class CPUOffloadingConnectorScheduler:
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.zmq_rpc_client.call("post_init")
if vllm_config.kv_transfer_config is not None:
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
"swap_in_threshold", 0)
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config("swap_in_threshold", 0)
else:
self.swap_in_threshold = 0
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
def get_num_new_matched_tokens(
self, ori_request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
def get_num_new_matched_tokens(self, ori_request: "Request", num_computed_tokens: int) -> tuple[int, bool]:
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
"get_matched_num_and_touch", request)
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call("get_matched_num_and_touch", request)
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
self.num_cpu_computed_tokens[
request.request_id] = num_cpu_computed_tokens
self.num_cpu_computed_tokens[request.request_id] = num_cpu_computed_tokens
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
return num_cpu_computed_tokens - num_computed_tokens, load_async
else:
@@ -189,29 +167,22 @@ class CPUOffloadingConnectorScheduler:
def update_state_after_alloc(self, request: "Request"):
self.allocated_req_ids.add(request.request_id)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
num_tokens = {}
# process scheduled_new_reqs
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
num_tokens[req_id] = (
req.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
num_tokens[req_id] = req.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]
# process scheduled_cached_reqs
cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, req_id in enumerate(cached_reqs.req_ids):
num_tokens[req_id] = (
cached_reqs.num_computed_tokens[idx] +
scheduler_output.num_scheduled_tokens[req_id])
num_tokens[req_id] = cached_reqs.num_computed_tokens[idx] + scheduler_output.num_scheduled_tokens[req_id]
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
self.allocated_req_ids -
scheduler_output.num_scheduled_tokens.keys())
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
num_tokens,
unallocated_req_ids)
unallocated_req_ids = set(
self.num_gpu_computed_tokens.keys() - self.allocated_req_ids - scheduler_output.num_scheduled_tokens.keys()
)
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", num_tokens, unallocated_req_ids)
metadata = CPUOffloadingConnectorMetadata(
requests={},
finished_req_ids=set(self.finished_req_ids),
@@ -222,22 +193,22 @@ class CPUOffloadingConnectorScheduler:
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
num_computed_tokens=req.num_computed_tokens,
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id],
)
for idx, req_id in enumerate(cached_reqs.req_ids):
gpu_block_ids = cached_reqs.new_block_ids[idx]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id],
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
)
self.num_gpu_computed_tokens.clear()
self.num_cpu_computed_tokens.clear()
self.allocated_req_ids.clear()
@@ -249,12 +220,10 @@ class CPUOffloadingConnectorScheduler:
request.get_hash_new_full_blocks = None
self.finished_req_ids.append(request.request_id)
# inform metadata server to record request, and free it after finish sending
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
request)
self.zmq_rpc_client.call("record_request_cache_and_free_slots", request)
class CPUOffloadingConnectorWorker:
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorWorker")
self.vllm_config = vllm_config
@@ -289,7 +258,7 @@ class CPUOffloadingConnectorWorker:
def init_metadata_server(self, vllm_config: VllmConfig):
self.metadata_thread = threading.Thread(
target=MetadataServerProc.run_metadata_server,
args=(vllm_config, ),
args=(vllm_config,),
)
self.metadata_thread.daemon = True
self.metadata_thread.start()
@@ -304,18 +273,15 @@ class CPUOffloadingConnectorWorker:
logger.info(f"wait for metadata server to start, error: {e}")
time.sleep(1)
def bind_connector_metadata(
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
def bind_connector_metadata(self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
for req_id, req in connector_metadata.requests.items():
if req_id in self.requests:
self.requests[req_id].update(req)
req = self.requests[req_id]
else:
self.requests[req_id] = req
for i in range(req.num_gpu_computed_tokens // self.block_size,
req.num_computed_tokens // self.block_size):
self.load_block_mapping.append(
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
for i in range(req.num_gpu_computed_tokens // self.block_size, req.num_computed_tokens // self.block_size):
self.load_block_mapping.append((req.cpu_block_ids[i], req.gpu_block_ids[i]))
for req_id in connector_metadata.finished_req_ids:
if req_id in self.requests:
self.save_input_queue.put((req_id, self.requests[req_id]))
@@ -326,11 +292,11 @@ class CPUOffloadingConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
self.gpu_kv_caches = kv_caches
model_config = self.vllm_config.model_config
mla_config: Optional[MLAConfig] = None
mla_config: MLAConfig | None = None
if model_config.use_mla:
mla_config = MLAConfig(
model_config.hf_text_config.kv_lora_rank,
model_config.hf_text_config.qk_rope_head_dim)
model_config.hf_text_config.kv_lora_rank, model_config.hf_text_config.qk_rope_head_dim
)
self.cpu_kv_caches = list(
self.zmq_rpc_client.call(
"init_cpu_kv_caches",
@@ -338,7 +304,8 @@ class CPUOffloadingConnectorWorker:
self.tp_rank,
get_kv_cache_spec(self.vllm_config),
mla_config,
).values())
).values()
)
def start_load_kv(self) -> None:
self.current_layer = 0
@@ -358,10 +325,8 @@ class CPUOffloadingConnectorWorker:
cpu_kv_caches = self.cpu_kv_caches[layer]
with torch.npu.stream(self.load_stream):
for cpu_block_id, gpu_block_id in self.load_block_mapping:
for gpu_layer_part, cpu_layer_part in zip(
gpu_kv_caches, cpu_kv_caches):
gpu_layer_part[gpu_block_id].copy_(
cpu_layer_part[cpu_block_id], non_blocking=True)
for gpu_layer_part, cpu_layer_part in zip(gpu_kv_caches, cpu_kv_caches):
gpu_layer_part[gpu_block_id].copy_(cpu_layer_part[cpu_block_id], non_blocking=True)
def get_finished(self) -> set[str]:
done_sending: set[str] = set()
@@ -380,8 +345,7 @@ class CPUOffloadingConnectorWorker:
self.done_sending_count[req_id] += 1
other_ranks_finished_ids: list[str] = []
for i in range(1, self.tp_world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
other_ranks_finished_ids.extend(self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
self.done_sending_count[req_id] += 1
all_done_sending: set[str] = set()
@@ -391,8 +355,7 @@ class CPUOffloadingConnectorWorker:
all_done_sending.add(req_id)
# release cpu_kv_cache after request sending finished
# to avoid rpc blocking, use thread to call rpc asynchronously
sending_finished_thread = threading.Thread(
target=self._sending_finished, args=(all_done_sending, ))
sending_finished_thread = threading.Thread(target=self._sending_finished, args=(all_done_sending,))
sending_finished_thread.daemon = True
sending_finished_thread.start()
@@ -411,11 +374,10 @@ class CPUOffloadingConnectorWorker:
while True:
req_id, req = self.save_input_queue.get()
for i in range(
req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) //
self.block_size, len(req.cpu_block_ids))):
save_block_mapping.append(
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
req.num_cpu_computed_tokens // self.block_size,
min((req.num_computed_tokens + req.num_scheduled_tokens) // self.block_size, len(req.cpu_block_ids)),
):
save_block_mapping.append((req.gpu_block_ids[i], req.cpu_block_ids[i]))
with torch.npu.stream(self.save_stream):
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
# non-MLA: kv_layer is list[tensor], typically means [k, v].
@@ -425,13 +387,9 @@ class CPUOffloadingConnectorWorker:
start, step = 0, 1
for i in range(start, len(save_block_mapping), step):
gpu_block_id, cpu_block_id = save_block_mapping[i]
for cpu_kv_caches, gpu_kv_caches in zip(
self.cpu_kv_caches, self.gpu_kv_caches.values()):
for cpu_layer_part, gpu_layer_part in zip(
cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(
gpu_layer_part[gpu_block_id],
non_blocking=True)
for cpu_kv_caches, gpu_kv_caches in zip(self.cpu_kv_caches, self.gpu_kv_caches.values()):
for cpu_layer_part, gpu_layer_part in zip(cpu_kv_caches, gpu_kv_caches):
cpu_layer_part[cpu_block_id].copy_(gpu_layer_part[gpu_block_id], non_blocking=True)
self.save_stream.synchronize()
self.save_output_queue.put(req_id)
save_block_mapping.clear()
@@ -453,8 +411,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
if vllm_config.cache_config.cache_dtype == "auto":
kv_cache_dtype = vllm_config.model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
vllm_config.cache_config.cache_dtype]
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[vllm_config.cache_config.cache_dtype]
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
@@ -472,10 +429,8 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
# using DSA. Fix the spec in vLLM is the final way.
block_size = vllm_config.cache_config.block_size
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=kv_cache_dtype)
block_size=block_size, num_kv_heads=1, head_size=attn_module.head_size, dtype=kv_cache_dtype
)
elif spec := attn_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec
@@ -484,8 +439,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
if len(mamba_layers) > 0:
if vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
raise NotImplementedError("Prefix caching is not supported for Mamba yet.")
for layer_name, mamba_module in mamba_layers.items():
if spec := mamba_module.get_kv_cache_spec(vllm_config):
kv_cache_spec[layer_name] = spec

`vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py`

@@ -1,9 +1,10 @@
import math
import os
import pickle
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional
from typing import Any
import torch
import vllm.envs as envs
@@ -14,8 +15,7 @@ from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.torch_utils import get_dtype_size
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import \
CPUKVCacheManager
from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import CPUKVCacheManager
@dataclass
@@ -30,8 +30,7 @@ def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
return kv_transfer_config
elif kv_transfer_config.kv_connector == "MultiConnector":
ktcs = kv_transfer_config.kv_connector_extra_config.get(
"connectors")
ktcs = kv_transfer_config.kv_connector_extra_config.get("connectors")
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
@@ -44,7 +43,6 @@ class MetadataServer:
DEFAULT_CPU_SWAP_SPACE_GB = 800
class ZMQRPCClient:
def __init__(self, identity=None):
if identity is None:
identity = f"worker-{os.getpid()}-{id(self)}"
@@ -56,7 +54,8 @@ class MetadataServer:
zmq.DEALER, # type: ignore
bind=False,
identity=identity.encode(),
linger=0)
linger=0,
)
def call(self, func_name: str, *args, **kwargs) -> Any:
request = (func_name, args, kwargs)
@@ -74,11 +73,9 @@ class MetadataServer:
self.shared_memory_dict = memory_dict
result = {}
for key, shm in memory_dict.items():
tensor = torch.frombuffer(
shm.buf, dtype=layer_dtype).reshape(layer_size)
tensor = torch.frombuffer(shm.buf, dtype=layer_dtype).reshape(layer_size)
if mla_config is not None:
tensor = tensor.split(
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
tensor = tensor.split([mla_config.nope_dim, mla_config.rope_dim], dim=-1)
result[key] = tensor
return result
@@ -86,7 +83,7 @@ class MetadataServer:
# will be finalized by outer process
self.socket.close()
self.ctx.term()
if hasattr(self, 'shared_memory_dict'):
if hasattr(self, "shared_memory_dict"):
for shm in self.shared_memory_dict.values():
shm.close()
@@ -96,7 +93,8 @@ class MetadataServer:
kv_transfer_config = get_cpu_offload_connector(vllm_config)
assert kv_transfer_config is not None
available_memory_gb = kv_transfer_config.get_from_extra_config(
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB
)
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
logger.info(f"cpu swap space: {self.available_memory} bytes")
self.ctx = zmq.Context() # type: ignore
@@ -105,7 +103,8 @@ class MetadataServer:
MetadataServer.METADATA_SERVER_ADDRESS,
zmq.ROUTER, # type: ignore
bind=True,
linger=0)
linger=0,
)
self.functions: dict[str, Callable] = {
"init_cpu_kv_caches": self.init_cpu_kv_caches,
"post_init": self.post_init,
@@ -133,15 +132,11 @@ class MetadataServer:
tp_rank: int,
kv_cache_specs: dict[str, AttentionSpec],
mla_config: MLAConfig,
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
MLAConfig]:
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, MLAConfig]:
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
# follow the assumption that each layer has the same spec
layer = next(iter(kv_cache_specs.values()))
assert all([
layer.page_size_bytes == any.page_size_bytes
for any in kv_cache_specs.values()
])
assert all([layer.page_size_bytes == any.page_size_bytes for any in kv_cache_specs.values()])
use_mla = isinstance(layer, MLAAttentionSpec)
# mla shares the same kv cache among different tp
if use_mla:
@@ -154,30 +149,24 @@ class MetadataServer:
available_memory //= self.pipeline_parallel_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore
else:
available_memory //= self.world_size
available_memory //= len(kv_cache_specs)
num_blocks = available_memory // layer.page_size_bytes
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
layer.head_size) # type: ignore
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
for layer_name in kv_cache_specs.keys():
for layer_name in kv_cache_specs:
# only this format can share during ZeroMQ+pickle
shared_memory_dict[
layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
shared_memory_dict[layer_name] = MetadataServer._safe_create_shared_memory(
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes
)
if use_mla:
assert mla_config is not None
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, mla_config)
self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, mla_config)
else:
self.shared_memory[(pp_rank,
tp_rank)] = (shared_memory_dict, layer_size,
layer.dtype, None)
self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, None)
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
self.num_cpu_blocks = num_blocks
self.layer = layer
@@ -185,23 +174,20 @@ class MetadataServer:
def post_init(self):
# different processors in data parallel may call multiple times
if hasattr(self, 'cpu_block_manager'):
if hasattr(self, "cpu_block_manager"):
return
# do shared_memory() at least once
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
assert self.num_cpu_blocks >= 0
self.cpu_block_manager = CPUKVCacheManager(self.layer,
self.num_cpu_blocks)
self.functions.update({
"get_matched_num_and_touch":
self.cpu_block_manager.get_matched_num_and_touch,
"allocate_slots":
self.cpu_block_manager.allocate_slots,
"record_request_cache_and_free_slots":
self.cpu_block_manager.record_request_cache_and_free_slots,
"cache_and_free_slots":
self.cpu_block_manager.cache_and_free_slots,
})
self.cpu_block_manager = CPUKVCacheManager(self.layer, self.num_cpu_blocks)
self.functions.update(
{
"get_matched_num_and_touch": self.cpu_block_manager.get_matched_num_and_touch,
"allocate_slots": self.cpu_block_manager.allocate_slots,
"record_request_cache_and_free_slots": self.cpu_block_manager.record_request_cache_and_free_slots,
"cache_and_free_slots": self.cpu_block_manager.cache_and_free_slots,
}
)
def serve_step(self):
client_id = self.socket.recv()
@@ -228,8 +214,7 @@ class MetadataServer:
def shutdown(self):
self.socket.close()
self.ctx.term()
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
"ipc://", "")
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
for cached in self.shared_memory.values():
@@ -239,11 +224,9 @@ class MetadataServer:
class MetadataServerProc:
@staticmethod
def run_metadata_server(vllm_config: VllmConfig):
if (not vllm_config.cache_config.enable_prefix_caching
or get_cpu_offload_connector(vllm_config) is None):
if not vllm_config.cache_config.enable_prefix_caching or get_cpu_offload_connector(vllm_config) is None:
return
shutdown_requested = False
@@ -257,7 +240,7 @@ class MetadataServerProc:
# Either SIGTERM or SIGINT will terminate the worker
# signal.signal(signal.SIGTERM, _signal_handler)
# signal.signal(signal.SIGINT, _signal_handler)
metadata_server: Optional[MetadataServer] = None
metadata_server: MetadataServer | None = None
try:
metadata_server = MetadataServer(vllm_config)
logger.info("Metadata server started.")