diff --git a/mypy.ini b/mypy.ini
index d093b9b9..a86a89df 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -1,6 +1,8 @@
 [mypy]
 ; warn_return_any = True
 warn_unused_configs = True
+; disable errors about unchecked annotations for now.
+disable_error_code = annotation-unchecked
 
 ; Suppress all missing import errors from torch_npu for mypy.
 [mypy-torch_npu.*]
@@ -31,4 +33,4 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 
 [mypy-ucm.*]
-ignore_missing_imports = True
\ No newline at end of file
+ignore_missing_imports = True
diff --git a/pyproject.toml b/pyproject.toml
index 4b32da9e..08ce30f2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,11 +51,6 @@ line-length = 120
 # Folder to be modified
 exclude = [
     "tests/**",
-    # (5)
-    "vllm_ascend/distributed/kv_transfer/kv_pool/**",
-    "vllm_ascend/distributed/kv_transfer/utils/**",
-    "vllm_ascend/kv_offload/**",
-    "vllm_ascend/lora/**",
     # (7)
     "vllm_ascend/quantization/**",
     "vllm_ascend/sample/*.py",
diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py
index 7661ef7a..f969bd92 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/ascend_store_connector.py
@@ -1,11 +1,10 @@
 import threading
-from typing import Any, Optional
+from typing import Any
 
 import torch
 import zmq
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
 from vllm.forward_context import ForwardContext
 from vllm.logger import logger
 from vllm.utils.network_utils import make_zmq_socket
@@ -17,31 +16,27 @@ from vllm.v1.request import Request
 from vllm.v1.serial_utils import MsgpackDecoder
 
 from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_scheduler import (
-    KVPoolScheduler, get_zmq_rpc_path_lookup)
-from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import \
-    KVPoolWorker
+    KVPoolScheduler,
+    get_zmq_rpc_path_lookup,
+)
+from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.pool_worker import KVPoolWorker
 
 
 class AscendStoreConnector(KVConnectorBase_V1):
-
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 role: KVConnectorRole,
-                 kv_cache_config: Optional[KVCacheConfig] = None):
-        super().__init__(vllm_config=vllm_config,
-                         role=role,
-                         kv_cache_config=kv_cache_config)
+    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: KVCacheConfig | None = None):
+        super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config)
         self.kv_role = vllm_config.kv_transfer_config.kv_role
-        self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "use_layerwise", False)
+        self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get("use_layerwise", False)
         self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "consumer_is_to_put", False)
+            "consumer_is_to_put", False
+        )
 
         connector_name = vllm_config.kv_transfer_config.kv_connector
         if connector_name == "MooncakeConnectorStoreV1":
             logger.warning(
-                "It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
+ "It is recommended to use the AscendStoreConnector, " + "as the MoonCakeStoreConnector will be removed in the future." ) self.kv_caches: dict[str, torch.Tensor] = {} @@ -49,8 +44,7 @@ class AscendStoreConnector(KVConnectorBase_V1): self.sended_but_unfinished_reqs: set[str] = set() if role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = KVPoolScheduler(vllm_config, - self.use_layerwise) + self.connector_scheduler = KVPoolScheduler(vllm_config, self.use_layerwise) else: self.connector_worker = KVPoolWorker( vllm_config, @@ -59,27 +53,19 @@ class AscendStoreConnector(KVConnectorBase_V1): assert self.connector_worker is not None if vllm_config.parallel_config.rank == 0: - self.lookup_server = LookupKeyServer(self.connector_worker, - vllm_config, - self.use_layerwise) + self.lookup_server = LookupKeyServer(self.connector_worker, vllm_config, self.use_layerwise) ############################################################ # Scheduler Side Methods ############################################################ - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]: assert self.connector_scheduler is not None - return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens) - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): assert self.connector_scheduler is not None - return self.connector_scheduler.update_state_after_alloc( - request, blocks, num_external_tokens) + return self.connector_scheduler.update_state_after_alloc(request, blocks, num_external_tokens) def build_connector_meta( self, @@ -92,7 +78,7 @@ class AscendStoreConnector(KVConnectorBase_V1): self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) @@ -103,8 +89,7 @@ class AscendStoreConnector(KVConnectorBase_V1): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None self.connector_worker.start_load_kv(self._get_connector_metadata()) @@ -113,8 +98,9 @@ class AscendStoreConnector(KVConnectorBase_V1): return self.connector_worker.wait_for_layer_load() - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs + ) -> None: if not self.use_layerwise: return @@ -133,17 +119,16 @@ class AscendStoreConnector(KVConnectorBase_V1): self.connector_worker.wait_for_save(self._get_connector_metadata()) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """Get the finished recving and sending requests.""" assert self.connector_worker is not None done_sending, done_recving = 
self.connector_worker.get_finished( - finished_req_ids, self._get_connector_metadata()) + finished_req_ids, self._get_connector_metadata() + ) return done_sending, done_recving class LookupKeyServer: - def __init__( self, pool_worker: KVPoolWorker, @@ -171,8 +156,7 @@ class LookupKeyServer: token_len = int.from_bytes(all_frames[0], byteorder="big") hash_frames = all_frames[1:] hashes_str = self.decoder.decode(hash_frames) - result = self.pool_worker.lookup_scheduler( - token_len, hashes_str, self.use_layerwise) + result = self.pool_worker.lookup_scheduler(token_len, hashes_str, self.use_layerwise) response = result.to_bytes(4, "big") self.socket.send(response) diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py index 3aeccbf3..c0115a00 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/backend.py @@ -4,13 +4,15 @@ from vllm.config import ParallelConfig class Backend(ABC): - + @abstractmethod def __init__(self, parallel_config: ParallelConfig): pass + @abstractmethod def set_device(self): pass + @abstractmethod def register_buffer(self, ptrs: list[int], lengths: list[int]): pass @@ -19,11 +21,9 @@ class Backend(ABC): pass @abstractmethod - def put(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]]): + def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): pass @abstractmethod - def get(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]]): + def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): pass diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py index 4769663b..fc5bc070 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/memcache_backend.py @@ -5,8 +5,7 @@ import torch from vllm.config import ParallelConfig from vllm.logger import logger -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \ - Backend +from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type @@ -18,7 +17,6 @@ class MmcDirect(Enum): class MemcacheBackend(Backend): - def __init__(self, parallel_config: ParallelConfig): try: from memcache_hybrid import DistributedObjectStore # type: ignore @@ -26,21 +24,17 @@ class MemcacheBackend(Backend): raise ImportError( "Please install memcache by following the instructions at " "https://gitee.com/ascend/memfabric_hybrid " # noqa: E501 - "to run vLLM with MemcacheConnector.") from e + "to run vLLM with MemcacheConnector." 
+ ) from e try: soc_version = get_ascend_device_type() if soc_version in {AscendDeviceType.A2}: import torch from vllm.distributed import get_world_group + tmp_tensor = torch.zeros(1, device="npu") - output_tensor_list = [ - torch.empty_like(tmp_tensor) - for _ in range(torch.distributed.get_world_size()) - ] - torch.distributed.all_gather( - output_tensor_list, - tmp_tensor, - group=get_world_group().device_group) + output_tensor_list = [torch.empty_like(tmp_tensor) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output_tensor_list, tmp_tensor, group=get_world_group().device_group) self.rank = parallel_config.rank self.store = DistributedObjectStore() res = self.store.init(self.rank) @@ -54,8 +48,7 @@ class MemcacheBackend(Backend): logger.error("Configuration loading failed: %s", e) raise except Exception as exc: - logger.error( - "An error occurred while loading the configuration: %s", exc) + logger.error("An error occurred while loading the configuration: %s", exc) raise def set_device(self): @@ -73,22 +66,18 @@ class MemcacheBackend(Backend): def exists(self, keys: list[str]) -> list[int]: return self.store.batch_is_exist(keys) - def get(self, key: list[str], addr: list[list[int]], - size: list[list[int]]): + def get(self, key: list[str], addr: list[list[int]], size: list[list[int]]): try: - res = self.store.batch_get_into_layers(key, addr, size, - MmcDirect.COPY_G2L.value) + res = self.store.batch_get_into_layers(key, addr, size, MmcDirect.COPY_G2L.value) for value in res: if value != 0: logger.error(f"Failed to get key {key},res:{res}") except Exception as e: logger.error(f"Failed to get key {key}. {e}") - def put(self, key: list[str], addr: list[list[int]], - size: list[list[int]]): + def put(self, key: list[str], addr: list[list[int]], size: list[list[int]]): try: - res = self.store.batch_put_from_layers(key, addr, size, - MmcDirect.COPY_L2G.value) + res = self.store.batch_put_from_layers(key, addr, size, MmcDirect.COPY_L2G.value) for value in res: if value != 0: logger.error(f"Failed to get key {key},res:{res}") diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py index 3375e741..0e6a0355 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py @@ -2,10 +2,9 @@ import json import os import re -import torch - from dataclasses import dataclass -from typing import Union + +import torch # Third Party from mooncake.store import ReplicateConfig # type: ignore @@ -13,17 +12,14 @@ from vllm.config import ParallelConfig from vllm.logger import logger from vllm.utils.network_utils import get_ip -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \ - Backend -from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import \ - global_te +from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend +from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB class MooncakeBackend(Backend): - def __init__(self, parallel_config: ParallelConfig): try: from mooncake.store import MooncakeDistributedStore # type: ignore @@ -31,23 +27,25 @@ class MooncakeBackend(Backend): raise ImportError( "Please 
install mooncake by following the instructions at " "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 - "to run vLLM with MooncakeConnector.") from e + "to run vLLM with MooncakeConnector." + ) from e self.config = MooncakeStoreConfig.load_from_env() self.store = MooncakeDistributedStore() self.rank = parallel_config.rank if self.config.protocol == "ascend": local_hostname = get_ip() - transfer_engine = global_te.get_transfer_engine(local_hostname, - device_name=None) - self.local_seg = local_hostname + ":" + str( - transfer_engine.get_rpc_port()) - ret = self.store.setup(self.local_seg, self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address, - transfer_engine.get_engine()) + transfer_engine = global_te.get_transfer_engine(local_hostname, device_name=None) + self.local_seg = local_hostname + ":" + str(transfer_engine.get_rpc_port()) + ret = self.store.setup( + self.local_seg, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + transfer_engine.get_engine(), + ) if ret != 0: msg = "Initialize mooncake failed." logger.error(msg) @@ -63,25 +61,21 @@ class MooncakeBackend(Backend): def exists(self, keys: list[str]) -> list[int]: return self.store.batch_is_exist(keys) - def put(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]]): + def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): try: config = ReplicateConfig() config.preferred_segment = self.local_seg config.prefer_alloc_in_same_node = True - res = self.store.batch_put_from_multi_buffers( - keys, addrs, sizes, config) + res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config) for value in res: if value < 0: logger.error(f"Failed to put key {keys},res:{res}") except Exception as e: logger.error(f"Failed to put key {keys},error:{e}") - def get(self, keys: list[str], addrs: list[list[int]], - sizes: list[list[int]]): + def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): try: - res = self.store.batch_get_into_multi_buffers( - keys, addrs, sizes, True) + res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True) for value in res: if value < 0: logger.error(f"Failed to get key {keys}, res:{res}") @@ -92,7 +86,7 @@ class MooncakeBackend(Backend): @dataclass class MooncakeStoreConfig: metadata_server: str - global_segment_size: Union[int, str] + global_segment_size: int | str local_buffer_size: int protocol: str device_name: str @@ -105,33 +99,32 @@ class MooncakeStoreConfig: return MooncakeStoreConfig( metadata_server=config.get("metadata_server"), global_segment_size=_parse_global_segment_size( - config.get("global_segment_size", - DEFAULT_GLOBAL_SEGMENT_SIZE)), - local_buffer_size=_parse_global_segment_size( - config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)), + config.get("global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE) + ), + local_buffer_size=_parse_global_segment_size(config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)), protocol=config.get("protocol", "ascend"), device_name=config.get("device_name", ""), - master_server_address=config.get("master_server_address")) + master_server_address=config.get("master_server_address"), + ) @staticmethod def load_from_env() -> "MooncakeStoreConfig": config_path = 
os.getenv("MOONCAKE_CONFIG_PATH") if not config_path: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + raise ValueError("The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") return MooncakeStoreConfig.from_file(config_path) def _parse_global_segment_size(value) -> int: """ Parse storage size strings with support for units: GB, MB, KB, B - + Args: value: Input value (int, str, or other convertible types) - + Returns: int: Size in bytes - + Raises: ValueError: For invalid format, missing number, or negative values TypeError: For unsupported input types @@ -143,54 +136,50 @@ def _parse_global_segment_size(value) -> int: try: return int(value) except (TypeError, ValueError) as e: - raise TypeError( - f"Unsupported type for global_segment_size: {type(value)}" - ) from e + raise TypeError(f"Unsupported type for global_segment_size: {type(value)}") from e cleaned_input = value.strip().lower() if not cleaned_input: raise ValueError("global segment size cannot be empty.") UNIT_MULTIPLIERS = { - 'gb': 1024**3, # 1 GB = 1024^3 bytes - 'mb': 1024**2, # 1 MB = 1024^2 bytes - 'kb': 1024, # 1 KB = 1024 bytes - 'b': 1 # 1 B = 1 byte + "gb": 1024**3, # 1 GB = 1024^3 bytes + "mb": 1024**2, # 1 MB = 1024^2 bytes + "kb": 1024, # 1 KB = 1024 bytes + "b": 1, # 1 B = 1 byte } - pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$' + pattern = r"^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$" match = re.match(pattern, cleaned_input) if not match: raise ValueError(f"Invalid format: '{value}'") number_str = match.group(1) - unit = match.group(2) or 'b' + unit = match.group(2) or "b" multiplier = UNIT_MULTIPLIERS[unit] return _convert_to_bytes(number_str, multiplier, value) -def _convert_to_bytes(number_str: str, multiplier: int, - original_input: str) -> int: +def _convert_to_bytes(number_str: str, multiplier: int, original_input: str) -> int: """ Convert numeric string to byte count - + Args: number_str: Numeric portion of input multiplier: Unit conversion factor original_input: Original input string (for error messages) - + Returns: int: Byte count - + Raises: ValueError: For invalid numbers or negative results """ try: numeric_value = float(number_str) except ValueError: - raise ValueError( - f"Invalid numeric value '{number_str}' in: '{original_input}'") + raise ValueError(f"Invalid numeric value '{number_str}' in: '{original_input}'") # Calculate byte count try: byte_count = int(numeric_value * multiplier) diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py index 676018ed..398cc3fb 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/config_data.py @@ -1,16 +1,16 @@ +from collections.abc import Iterable from dataclasses import dataclass -from typing import Iterable, List, Optional, Tuple, Union +from typing import Optional import torch -from vllm.distributed.kv_transfer.kv_connector.v1.base import \ - KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.logger import logger from vllm.utils.math_utils import cdiv from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import NewRequestData -#Parameters related to the key +# Parameters related to the key @dataclass class KeyMetadata: """name of the LLM model""" @@ -32,23 +32,26 @@ class PoolKey: chunk_hash: str def __hash__(self): - return hash(( - 
self.key_metadata.model_name, - self.key_metadata.head_or_tp_rank, - self.key_metadata.pcp_rank, - self.key_metadata.dcp_rank, - self.key_metadata.pp_rank, - self.chunk_hash, - )) + return hash( + ( + self.key_metadata.model_name, + self.key_metadata.head_or_tp_rank, + self.key_metadata.pcp_rank, + self.key_metadata.dcp_rank, + self.key_metadata.pp_rank, + self.chunk_hash, + ) + ) def to_string(self): return ( f"{self.key_metadata.model_name}" f"@pcp{self.key_metadata.pcp_rank}@dcp{self.key_metadata.dcp_rank}" f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}" - f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}") + f"@pp_rank:{self.key_metadata.pp_rank}@{self.chunk_hash}" + ) - def split_layers(self, num_layers: int) -> List["LayerPoolKey"]: + def split_layers(self, num_layers: int) -> list["LayerPoolKey"]: """Split the key into multiple keys for each layer""" keys = [] for layer_id in range(num_layers): @@ -57,7 +60,8 @@ class PoolKey: self.key_metadata, self.chunk_hash, layer_id, - )) + ) + ) return keys @@ -68,14 +72,16 @@ class LayerPoolKey(PoolKey): layer_id: int def __hash__(self): - return hash(( - self.key_metadata.model_name, - self.key_metadata.head_or_tp_rank, - self.key_metadata.pcp_rank, - self.key_metadata.dcp_rank, - self.chunk_hash, - self.layer_id, - )) + return hash( + ( + self.key_metadata.model_name, + self.key_metadata.head_or_tp_rank, + self.key_metadata.pcp_rank, + self.key_metadata.dcp_rank, + self.chunk_hash, + self.layer_id, + ) + ) def to_string(self): return ( @@ -85,10 +91,8 @@ class LayerPoolKey(PoolKey): ) -class ChunkedTokenDatabase(): - - def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, - partitions: Optional[List[int]]): +class ChunkedTokenDatabase: + def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool, partitions: list[int] | None): self.metadata = metadata self.block_size = block_size self.use_mla = use_mla @@ -96,9 +100,7 @@ class ChunkedTokenDatabase(): self.block_len: list[int] = [] self.partitions = partitions - def _make_key_by_hash(self, - chunk_hash: str, - layer_id: Optional[int] = None): + def _make_key_by_hash(self, chunk_hash: str, layer_id: int | None = None): assert self.metadata is not None return PoolKey( self.metadata, @@ -116,8 +118,7 @@ class ChunkedTokenDatabase(): size_list = [] block_id = block_ids[start // self.block_size] for index, base_addr in enumerate(self.kv_caches_base_addr): - block_len = (self.block_len[index % 2] - if self.use_mla else self.block_len[0]) + block_len = self.block_len[index % 2] if self.use_mla else self.block_len[0] addr = base_addr + block_id * block_len length = int(block_len / self.block_size * (end - start)) @@ -125,22 +126,17 @@ class ChunkedTokenDatabase(): size_list.append(length) return addr_list, size_list, block_id - def prepare_value_layer(self, start: int, end: int, block_ids: list[int], - layer_id: int): + def prepare_value_layer(self, start: int, end: int, block_ids: list[int], layer_id: int): block_id = block_ids[start // self.block_size] if self.use_mla: - addr_k = self.kv_caches_base_addr[layer_id * - 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + - 1] + block_id * self.block_len[1] + addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[1] length_k = int(self.block_len[0] / self.block_size * (end - start)) length_v = int(self.block_len[1] / self.block_size * (end - start)) size_list = 
[length_k, length_v] else: - addr_k = self.kv_caches_base_addr[layer_id * - 2] + block_id * self.block_len[0] - addr_v = self.kv_caches_base_addr[layer_id * 2 + - 1] + block_id * self.block_len[0] + addr_k = self.kv_caches_base_addr[layer_id * 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + 1] + block_id * self.block_len[0] length = int(self.block_len[0] / self.block_size * (end - start)) size_list = [length, length] addr_list = [addr_k, addr_v] @@ -149,9 +145,9 @@ class ChunkedTokenDatabase(): def process_tokens( self, token_len: int, - block_hashes: Union[list[BlockHash], list[str]], + block_hashes: list[BlockHash] | list[str], mask_num: int = 0, - ) -> Iterable[Tuple[int, int, PoolKey]]: + ) -> Iterable[tuple[int, int, PoolKey]]: """Process the tokens and return the corresponding cache engine keys. :param Union[torch.Tensor, List[int]] tokens: The tokens to process. @@ -202,10 +198,10 @@ class ChunkedTokenDatabase(): start = 0 for j, part in enumerate(self.partitions): # part * 2 because addr and size contain both k and v - end = len(addr_list) if j == len( - self.partitions) - 1 else start + part * 2 + end = len(addr_list) if j == len(self.partitions) - 1 else start + part * 2 new_str = key[i].replace( # type: ignore[attr-defined] - "@pp_rank:0", f"@pp_rank:{j}", 1) + "@pp_rank:0", f"@pp_rank:{j}", 1 + ) new_key.append(new_str) new_addr.append(addr_list[start:end]) new_size.append(size_list[start:end]) @@ -213,7 +209,7 @@ class ChunkedTokenDatabase(): return new_key, new_addr, new_size -#Parameters related to the connector metadata +# Parameters related to the connector metadata @dataclass class LoadSpec: # Number of tokens cached in vLLM @@ -273,7 +269,7 @@ class RequestTracker: def update( self, - new_block_ids: Union[tuple[list[int], ...], list[int]], + new_block_ids: tuple[list[int], ...] | list[int], ) -> None: """Update the request tracker when a running request is scheduled again @@ -286,8 +282,7 @@ class RequestTracker: elif isinstance(new_block_ids, list): pass else: - raise ValueError( - f"Unsupported new_block_ids type {type(new_block_ids)}") + raise ValueError(f"Unsupported new_block_ids type {type(new_block_ids)}") self.allocated_block_ids.extend(new_block_ids) @@ -302,22 +297,22 @@ class ReqMeta: block_hashes: list[BlockHash] - can_save: Optional[bool] = None + can_save: bool | None = None # load_spec - load_spec: Optional[LoadSpec] = None + load_spec: LoadSpec | None = None - is_last_chunk: Optional[bool] = None + is_last_chunk: bool | None = None - current_event: Optional[torch.npu.Event] = None + current_event: torch.npu.Event | None = None @staticmethod def from_request_tracker( tracker: RequestTracker, block_size: int, - load_spec: Optional[LoadSpec] = None, - skip_save: Optional[bool] = False, - block_hashes: list[BlockHash] = [], - is_last_chunk: Optional[bool] = None, + load_spec: LoadSpec | None = None, + skip_save: bool | None = False, + block_hashes: list[BlockHash] | None = None, + is_last_chunk: bool | None = None, discard_partial_chunks: bool = True, ) -> Optional["ReqMeta"]: """Create the request metadata from a request tracker. @@ -333,17 +328,17 @@ class ReqMeta: the request metadata if we need to perform load/save operations, None otherwise. """ + if block_hashes is None: + block_hashes = [] input_token_len = tracker.token_len # For save operation: do not save if the following condition is met # 1. has already been saved before (num_saved_tokens > 0) # 2. 
number of unsaved tokens is not reached the chunk boundary - chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * - block_size if discard_partial_chunks else 0) + chunk_boundary = cdiv(tracker.num_saved_tokens + 1, block_size) * block_size if discard_partial_chunks else 0 # Calculate number of tokens to save based on discard_partial_chunks # setting - num_tokens_to_save = ((input_token_len // block_size * block_size) - if discard_partial_chunks else input_token_len) + num_tokens_to_save = (input_token_len // block_size * block_size) if discard_partial_chunks else input_token_len skip_save = skip_save or num_tokens_to_save < chunk_boundary if skip_save and load_spec is None: @@ -363,9 +358,7 @@ class ReqMeta: else: # Do not load if not in `can_load` state load_spec = None - logger.debug( - f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}" - ) + logger.debug(f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}") return ReqMeta( req_id=tracker.req_id, token_len_chunk=num_tokens_to_save, @@ -378,7 +371,6 @@ class ReqMeta: class AscendConnectorMetadata(KVConnectorMetadata): - def __init__(self, unfinished_request_ids, preempted_req_ids): self.requests = [] self.unfinished_request_ids = unfinished_request_ids @@ -396,10 +388,10 @@ class AscendConnectorMetadata(KVConnectorMetadata): @dataclass class LasyerMultiBlockReqMeta: req_id: str - keys: List[LayerPoolKey] - starts: List[int] + keys: list[LayerPoolKey] + starts: list[int] ends: list[int] block_ids: list[int] layer_id: int - is_last_chunk: Optional[bool] = True - current_event: Optional[torch.npu.Event] = None + is_last_chunk: bool | None = True + current_event: torch.npu.Event | None = None diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py index 84f289bf..ec8c7041 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/kv_transfer.py @@ -7,8 +7,7 @@ from typing import Any import torch from vllm.logger import logger -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \ - Backend +from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend # isort: off from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( @@ -20,10 +19,16 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import class KVTransferThread(threading.Thread): - - def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - block_size: int, tp_rank: int, dcp_size: int, - ready_event: threading.Event, name: str): + def __init__( + self, + m_store: Backend, + token_database: ChunkedTokenDatabase, + block_size: int, + tp_rank: int, + dcp_size: int, + ready_event: threading.Event, + name: str, + ): super().__init__(daemon=True, name=name) self.m_store = m_store self.ready_event = ready_event @@ -39,7 +44,7 @@ class KVTransferThread(threading.Thread): def add_request( self, - request: ReqMeta, + request: ReqMeta | LasyerMultiBlockReqMeta, ) -> torch.Tensor: self.request_queue.put(request) @@ -98,17 +103,20 @@ class KVTransferThread(threading.Thread): class KVCacheStoreSendingThread(KVTransferThread): - - def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - block_size: int, tp_rank: int, dcp_size: int, put_step: int, - kv_role: str, ready_event: 
threading.Event): - super().__init__(m_store, - token_database, - block_size, - tp_rank, - dcp_size, - ready_event, - name="KVCacheSendingThread") + def __init__( + self, + m_store: Backend, + token_database: ChunkedTokenDatabase, + block_size: int, + tp_rank: int, + dcp_size: int, + put_step: int, + kv_role: str, + ready_event: threading.Event, + ): + super().__init__( + m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheSendingThread" + ) self.put_step = put_step self.kv_role = kv_role self.stored_requests = defaultdict[str, int](int) @@ -139,16 +147,15 @@ class KVCacheStoreSendingThread(KVTransferThread): self.request_queue.task_done() return - for start, end, key in self.token_database.process_tokens( - token_len, req_meta.block_hashes): + for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes): starts.append(start) ends.append(end) keys.append(key.to_string()) if not self.dcp_size > 1: - starts = starts[self.tp_rank % self.put_step::self.put_step] - ends = ends[self.tp_rank % self.put_step::self.put_step] - keys = keys[self.tp_rank % self.put_step::self.put_step] + starts = starts[self.tp_rank % self.put_step :: self.put_step] + ends = ends[self.tp_rank % self.put_step :: self.put_step] + keys = keys[self.tp_rank % self.put_step :: self.put_step] if not keys: self.dec_stored_request(req_id) @@ -165,8 +172,7 @@ class KVCacheStoreSendingThread(KVTransferThread): keys = keys[skip_block_num:] logger.info( - "Storing KV cache for %d out of %d blocks " - "(skip_block_num=%d) for request %s", + "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s", len(keys), token_len // self.block_size, skip_block_num, @@ -183,14 +189,12 @@ class KVCacheStoreSendingThread(KVTransferThread): addrs = [] sizes = [] for index, start in enumerate(starts): - addr, size, _ = self.token_database.prepare_value( - start, ends[index], block_ids) + addr, size, _ = self.token_database.prepare_value(start, ends[index], block_ids) addrs.append(addr) sizes.append(size) if self.kv_role == "kv_consumer": - keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp( - keys, addrs, sizes) + keys, addrs, sizes = self.token_database.decode_adaptor_prefill_pp(keys, addrs, sizes) if current_event is not None: current_event.synchronize() @@ -201,69 +205,69 @@ class KVCacheStoreSendingThread(KVTransferThread): class KVCacheStoreRecvingThread(KVTransferThread): - - def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - block_size: int, tp_rank: int, dcp_size: int, - ready_event: threading.Event): - super().__init__(m_store, - token_database, - block_size, - tp_rank, - dcp_size, - ready_event, - name="KVCacheStoreRecvingThread") + def __init__( + self, + m_store: Backend, + token_database: ChunkedTokenDatabase, + block_size: int, + tp_rank: int, + dcp_size: int, + ready_event: threading.Event, + ): + super().__init__( + m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreRecvingThread" + ) def _handle_request(self, req_meta: ReqMeta): token_len = req_meta.load_spec.token_len # type: ignore[union-attr] req_id = req_meta.req_id mask_num = ( req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr] - // self.block_size * self.block_size) + // self.block_size + * self.block_size + ) addr_list = [] size_list = [] key_list = [] - for start, end, key in self.token_database.process_tokens( - token_len, req_meta.block_hashes, mask_num): - addr, size, _ = 
self.token_database.prepare_value( - start, end, req_meta.block_ids) + for start, end, key in self.token_database.process_tokens(token_len, req_meta.block_hashes, mask_num): + addr, size, _ = self.token_database.prepare_value(start, end, req_meta.block_ids) key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) - key_list_c = key_list[self.tp_rank % - len(key_list):] + key_list[:self.tp_rank % - len(key_list)] - addr_list_c = addr_list[self.tp_rank % - len(addr_list):] + addr_list[:self.tp_rank % - len(addr_list)] - size_list_c = size_list[self.tp_rank % - len(size_list):] + size_list[:self.tp_rank % - len(size_list)] + key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)] + addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)] + size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)] self.m_store.get(key_list_c, addr_list_c, size_list_c) self.set_finished_request(req_id) self.request_queue.task_done() class KVCacheStoreLayerSendingThread(KVTransferThread): - - def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - block_size: int, tp_rank: int, dcp_size: int, put_step: int, - ready_event: threading.Event, num_layers: int): - super().__init__(m_store, - token_database, - block_size, - tp_rank, - dcp_size, - ready_event, - name="KVCacheStoreLayerSendingThread") + def __init__( + self, + m_store: Backend, + token_database: ChunkedTokenDatabase, + block_size: int, + tp_rank: int, + dcp_size: int, + put_step: int, + ready_event: threading.Event, + num_layers: int, + ): + super().__init__( + m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerSendingThread" + ) self.final_layer_id = num_layers - 1 self.put_step = put_step def add_request( # type: ignore[override] - self, req_meta: ReqMeta) -> torch.Tensor: + self, req_meta: ReqMeta + ) -> torch.Tensor: self.request_queue.put(req_meta) def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta): + self, req_meta: LasyerMultiBlockReqMeta + ): starts = req_meta.starts ends = req_meta.ends keys = req_meta.keys @@ -272,9 +276,9 @@ class KVCacheStoreLayerSendingThread(KVTransferThread): total_block = len(keys) is_last_chunk = req_meta.is_last_chunk if not self.dcp_size > 1: - starts = starts[self.tp_rank % self.put_step::self.put_step] - ends = ends[self.tp_rank % self.put_step::self.put_step] - keys = keys[self.tp_rank % self.put_step::self.put_step] + starts = starts[self.tp_rank % self.put_step :: self.put_step] + ends = ends[self.tp_rank % self.put_step :: self.put_step] + keys = keys[self.tp_rank % self.put_step :: self.put_step] if not keys: if is_last_chunk: @@ -300,7 +304,8 @@ class KVCacheStoreLayerSendingThread(KVTransferThread): size_list = [] for index, key in enumerate(key_list): addr, size = self.token_database.prepare_value_layer( - starts[index], ends[index], req_meta.block_ids, layer_id) + starts[index], ends[index], req_meta.block_ids, layer_id + ) addr_list.append(addr) size_list.append(size) @@ -313,8 +318,7 @@ class KVCacheStoreLayerSendingThread(KVTransferThread): self.request_queue.task_done() logger.info( - "Storing KV cache for %d out of %d blocks " - "(skip_block_num=%d) for request %s", + "Storing KV cache for %d out of %d blocks (skip_block_num=%d) for request %s", len(keys), total_block, skip_block_num, @@ -323,44 +327,42 @@ class 
KVCacheStoreLayerSendingThread(KVTransferThread): class KVCacheStoreLayerRecvingThread(KVTransferThread): - - def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase, - block_size: int, tp_rank: int, dcp_size: int, - ready_event: threading.Event, get_event: threading.Event): - super().__init__(m_store, - token_database, - block_size, - tp_rank, - dcp_size, - ready_event, - name="KVCacheStoreLayerRecvingThread") + def __init__( + self, + m_store: Backend, + token_database: ChunkedTokenDatabase, + block_size: int, + tp_rank: int, + dcp_size: int, + ready_event: threading.Event, + get_event: threading.Event, + ): + super().__init__( + m_store, token_database, block_size, tp_rank, dcp_size, ready_event, name="KVCacheStoreLayerRecvingThread" + ) self.get_event = get_event def add_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self, req_meta: LasyerMultiBlockReqMeta + ) -> torch.Tensor: self.request_queue.put(req_meta) def _handle_request( # type: ignore[override] - self, req_meta: LasyerMultiBlockReqMeta): + self, req_meta: LasyerMultiBlockReqMeta + ): addr_list = [] size_list = [] key_list = [] for index, key in enumerate(req_meta.keys): addr, size = self.token_database.prepare_value_layer( - req_meta.starts[index], req_meta.ends[index], - req_meta.block_ids, req_meta.layer_id) + req_meta.starts[index], req_meta.ends[index], req_meta.block_ids, req_meta.layer_id + ) key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) - key_list_c = key_list[self.tp_rank % - len(key_list):] + key_list[:self.tp_rank % - len(key_list)] - addr_list_c = addr_list[self.tp_rank % - len(addr_list):] + addr_list[:self.tp_rank % - len(addr_list)] - size_list_c = size_list[self.tp_rank % - len(size_list):] + size_list[:self.tp_rank % - len(size_list)] + key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)] + addr_list_c = addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)] + size_list_c = size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)] self.m_store.get(key_list_c, addr_list_c, size_list_c) self.request_queue.task_done() diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py index cce4c53d..51e7db70 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_scheduler.py @@ -1,10 +1,9 @@ -from typing import Any, Optional +from typing import Any import vllm.envs as envs import zmq from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.base import \ - KVConnectorMetadata +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from vllm.logger import logger from vllm.utils.network_utils import make_zmq_socket from vllm.v1.core.kv_cache_manager import KVCacheBlocks @@ -14,27 +13,29 @@ from vllm.v1.request import Request from vllm.v1.serial_utils import MsgpackEncoder from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( - AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker) + AscendConnectorMetadata, + LoadSpec, + ReqMeta, + RequestTracker, +) class KVPoolScheduler: - def __init__(self, vllm_config: "VllmConfig", use_layerwise): self.use_layerwise = use_layerwise self.kv_role = vllm_config.kv_transfer_config.kv_role 
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "consumer_is_to_load", False) + "consumer_is_to_load", False + ) self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "consumer_is_to_put", False) - self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "load_async", False) + "consumer_is_to_put", False + ) + self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False) self.client = LookupKeyClient(vllm_config) # request_id -> (vllm cached tokes, kvpool cached tokens) self.load_specs: dict[str, LoadSpec] = {} - self.pcp_size = getattr(vllm_config.parallel_config, - "prefill_context_parallel_size", 1) - self.dcp_size = getattr(vllm_config.parallel_config, - "decode_context_parallel_size", 1) + self.pcp_size = getattr(vllm_config.parallel_config, "prefill_context_parallel_size", 1) + self.dcp_size = getattr(vllm_config.parallel_config, "decode_context_parallel_size", 1) self._block_size = vllm_config.cache_config.block_size if self.pcp_size > 1: @@ -45,9 +46,9 @@ class KVPoolScheduler: self._request_trackers: dict[str, RequestTracker] = {} self._preempted_req_ids: set[str] = set() # Whether to discard partial chunks - self._discard_partial_chunks = ( - vllm_config.kv_transfer_config.get_from_extra_config( - "discard_partial_chunks", True)) + self._discard_partial_chunks = vllm_config.kv_transfer_config.get_from_extra_config( + "discard_partial_chunks", True + ) self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {} self._unfinished_request_ids: set[str] = set() @@ -72,13 +73,11 @@ class KVPoolScheduler: return 0, False if self._discard_partial_chunks: - token_len = len(request.prompt_token_ids - ) // self._block_size * self._block_size + token_len = len(request.prompt_token_ids) // self._block_size * self._block_size else: token_len = len(request.prompt_token_ids) - num_external_hit_tokens = self.client.lookup(token_len, - request.block_hashes) + num_external_hit_tokens = self.client.lookup(token_len, request.block_hashes) if num_external_hit_tokens == request.num_tokens: num_external_hit_tokens -= 1 @@ -107,9 +106,7 @@ class KVPoolScheduler: return need_to_allocate, self.load_async and not self.use_layerwise - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): """ Update KVConnector state after temporary buffer alloc. 
@@ -120,8 +117,7 @@ class KVPoolScheduler: if num_external_tokens > 0: local_block_ids = blocks.get_block_ids()[0] - self._unfinished_requests[request.request_id] = (request, - local_block_ids) + self._unfinished_requests[request.request_id] = (request, local_block_ids) self._unfinished_request_ids.add(request.request_id) if request.request_id not in self.load_specs: # No KV tokens from external KV cache, return @@ -133,18 +129,20 @@ class KVPoolScheduler: return assert ( - num_external_tokens > 0 and num_external_tokens - == self.load_specs[request.request_id].kvpool_cached_tokens - - self.load_specs[request.request_id].vllm_cached_tokens - ), (f"Mismatch in number of tokens: {num_external_tokens} vs " + num_external_tokens > 0 + and num_external_tokens + == self.load_specs[request.request_id].kvpool_cached_tokens + - self.load_specs[request.request_id].vllm_cached_tokens + ), ( + f"Mismatch in number of tokens: {num_external_tokens} vs " f"{self.load_specs[request.request_id].kvpool_cached_tokens} - " f"{self.load_specs[request.request_id].vllm_cached_tokens}" - f" for request {request.request_id}") + f" for request {request.request_id}" + ) self.load_specs[request.request_id].can_load = True - def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: """Attach the connector metadata to the request object. This function should NOT modify other fields in the scheduler_output @@ -155,14 +153,13 @@ class KVPoolScheduler: scheduler_output (SchedulerOutput): the scheduler output object. """ - force_skip_save = (self.kv_role == "kv_consumer" - and not self.consumer_is_to_put) + force_skip_save = self.kv_role == "kv_consumer" and not self.consumer_is_to_put for finished_req_id in scheduler_output.finished_req_ids: self._request_trackers.pop(finished_req_id, None) self._unfinished_requests.pop(finished_req_id, None) self._unfinished_request_ids.discard(finished_req_id) - + for req_id in scheduler_output.preempted_req_ids: self._preempted_req_ids.update(scheduler_output.preempted_req_ids) self._request_trackers.pop(req_id, None) @@ -173,9 +170,7 @@ class KVPoolScheduler: for request in scheduler_output.scheduled_new_reqs: # Right now, we only load KV for new requests load_spec = self.load_specs.pop(request.req_id, None) - num_tokens_to_compute = ( - request.num_computed_tokens + - scheduler_output.num_scheduled_tokens[request.req_id]) + num_tokens_to_compute = request.num_computed_tokens + scheduler_output.num_scheduled_tokens[request.req_id] request_tuple = self._unfinished_requests.get(request.req_id) request_real = request_tuple[0] # type: ignore[index] if not isinstance(request.block_ids[0], list): @@ -183,25 +178,25 @@ class KVPoolScheduler: else: unfolded_block_ids = request.block_ids[0].copy() request_tracker = RequestTracker( - req_id=request.req_id, - token_len=num_tokens_to_compute, - allocated_block_ids=unfolded_block_ids, - num_saved_tokens=0, - ) + req_id=request.req_id, + token_len=num_tokens_to_compute, + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) self._request_trackers[request.req_id] = request_tracker - last_chunk_tokens_num = ((len(request.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else len( - request.prompt_token_ids)) - + last_chunk_tokens_num = ( + (len(request.prompt_token_ids) // self._block_size * self._block_size) + if self._discard_partial_chunks + else 
len(request.prompt_token_ids) + ) + req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, load_spec=load_spec, skip_save=force_skip_save, block_hashes=request_real.block_hashes, - is_last_chunk=request_tracker.token_len - >= last_chunk_tokens_num, + is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) if req_meta is not None: @@ -224,8 +219,8 @@ class KVPoolScheduler: request_tuple = self._unfinished_requests.get(req_id) request_real = request_tuple[0] # type: ignore[index] num_tokens_to_compute = ( - request_real.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) + request_real.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] + ) request_tracker = RequestTracker( req_id=req_id, token_len=num_tokens_to_compute, @@ -233,21 +228,21 @@ class KVPoolScheduler: num_saved_tokens=0, ) self._request_trackers[req_id] = request_tracker - last_chunk_tokens_num = ((len(request_real.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else len( - request_real.prompt_token_ids)) + last_chunk_tokens_num = ( + (len(request_real.prompt_token_ids) // self._block_size * self._block_size) + if self._discard_partial_chunks + else len(request_real.prompt_token_ids) + ) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, load_spec=load_spec, skip_save=force_skip_save, block_hashes=request_real.block_hashes, - is_last_chunk=request_tracker.token_len - >= last_chunk_tokens_num, + is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) - + # decode/chunked request else: request_tracker = self._request_trackers[req_id] @@ -256,48 +251,44 @@ class KVPoolScheduler: if req_tuple: request = req_tuple[0] num_current_tokens = request_tracker.token_len - new_token_ids = request.all_token_ids[ - num_current_tokens:num_current_tokens + num_new_tokens] + new_token_ids = request.all_token_ids[num_current_tokens : num_current_tokens + num_new_tokens] request_tracker.token_len += len(new_token_ids) else: raise ValueError( - f"Request {req_id} is not in _unfinished_requests, " - f"but it is scheduled to be cached") + f"Request {req_id} is not in _unfinished_requests, but it is scheduled to be cached" + ) num_computed_token = cached_reqs.num_computed_tokens[i] if num_computed_token >= len(request.prompt_token_ids): continue request_tracker.update(new_block_ids) - last_chunk_tokens_num = ((len(request.prompt_token_ids) // - self._block_size * self._block_size) - if self._discard_partial_chunks else - len(request.prompt_token_ids)) + last_chunk_tokens_num = ( + (len(request.prompt_token_ids) // self._block_size * self._block_size) + if self._discard_partial_chunks + else len(request.prompt_token_ids) + ) req_meta = ReqMeta.from_request_tracker( request_tracker, self._block_size, load_spec=None, skip_save=force_skip_save, block_hashes=request.block_hashes, - is_last_chunk=request_tracker.token_len - >= last_chunk_tokens_num, + is_last_chunk=request_tracker.token_len >= last_chunk_tokens_num, discard_partial_chunks=self._discard_partial_chunks, ) if req_meta is not None: meta.add_request(req_meta) - request_ids = [ - req.req_id for req in scheduler_output.scheduled_new_reqs - ] - for request_id, (request, - block_ids) in self._unfinished_requests.items(): + request_ids = [req.req_id for req in scheduler_output.scheduled_new_reqs] + for request_id, (request, block_ids) in 
self._unfinished_requests.items(): if request_id not in request_ids and request_id not in cached_reqs.req_ids: load_spec = self.load_specs.pop(request_id, None) if not load_spec: continue num_tokens_to_compute = load_spec.kvpool_cached_tokens - if (num_tokens_to_compute % self._block_size - != 0) and (num_tokens_to_compute - == len(request.prompt_token_ids) - 1): + if (num_tokens_to_compute % self._block_size != 0) and ( + num_tokens_to_compute == len(request.prompt_token_ids) - 1 + ): num_tokens_to_compute = num_tokens_to_compute + 1 request_tracker = RequestTracker( req_id=request_id, @@ -324,7 +315,7 @@ class KVPoolScheduler: self, request: "Request", block_ids: list[int], - ) -> tuple[bool, Optional[dict[str, Any]]]: + ) -> tuple[bool, dict[str, Any] | None]: """ Once a request is finished, determine whether request blocks should be freed now or will be sent asynchronously and freed later. @@ -336,13 +327,11 @@ class KVPoolScheduler: return False, None delay_free_blocks = len(block_ids) > 0 if delay_free_blocks: - logger.info("Delaying free of %d blocks for request %s", - len(block_ids), request.request_id) + logger.info("Delaying free of %d blocks for request %s", len(block_ids), request.request_id) return delay_free_blocks, None class LookupKeyClient: - def __init__(self, vllm_config: "VllmConfig"): self.encoder = MsgpackEncoder() self.ctx = zmq.Context() # type: ignore[attr-defined] diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py index 25080a2b..4e69a112 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py @@ -1,37 +1,45 @@ import math import threading -from typing import Dict, Generator, Optional, Type +from collections.abc import Callable, Generator import torch from vllm.config import VllmConfig -from vllm.distributed import (get_decode_context_model_parallel_rank, - get_decode_context_model_parallel_world_size, - get_pcp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import ( + get_decode_context_model_parallel_rank, + get_decode_context_model_parallel_world_size, + get_pcp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.logger import logger from vllm.v1.core.kv_cache_utils import BlockHash -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import \ - Backend -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import \ - MemcacheBackend -from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import \ - MooncakeBackend +from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend +from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.memcache_backend import MemcacheBackend +from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.mooncake_backend import MooncakeBackend from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.config_data import ( - AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata, - LasyerMultiBlockReqMeta, ReqMeta) + AscendConnectorMetadata, + ChunkedTokenDatabase, + KeyMetadata, + LasyerMultiBlockReqMeta, + ReqMeta, +) from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import ( - KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, - 
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) + KVCacheStoreLayerRecvingThread, + KVCacheStoreLayerSendingThread, + KVCacheStoreRecvingThread, + KVCacheStoreSendingThread, + KVTransferThread, +) -backend_map: Dict[str, Type[Backend]] = { +backend_map: dict[str, Callable[..., Backend]] = { "mooncake": MooncakeBackend, "memcache": MemcacheBackend, } class KVPoolWorker: - #The main class for the cache engine. + # The main class for the cache engine. def __init__( self, @@ -42,9 +50,7 @@ class KVPoolWorker: parallel_config = vllm_config.parallel_config self.dp_rank = parallel_config.data_parallel_rank self.use_mla = False - if (hasattr(model_config, "use_mla") - and isinstance(model_config.use_mla, bool) - and model_config.use_mla): + if hasattr(model_config, "use_mla") and isinstance(model_config.use_mla, bool) and model_config.use_mla: self.use_mla = True self.use_layerwise = use_layerwize self.tp_rank = get_tensor_model_parallel_rank() @@ -53,19 +59,16 @@ class KVPoolWorker: self.pp_rank = (parallel_config.rank // self.tp_size) % self.pp_size self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_size > 1 else 0 + self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() - self.dcp_rank = get_decode_context_model_parallel_rank( - ) if self.dcp_size > 1 else 0 + self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0 self.kv_role = vllm_config.kv_transfer_config.kv_role - self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "load_async", False) + self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get("load_async", False) self.consumer_is_to_put = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "consumer_is_to_put", False) - self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "backend", "mooncake") + "consumer_is_to_put", False + ) + self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get("backend", "mooncake") self.block_size = vllm_config.cache_config.block_size if self.pcp_size > 1: @@ -88,7 +91,7 @@ class KVPoolWorker: self.put_step = 1 self.metadata = KeyMetadata( - model_config.model.rstrip('/').split('/')[-1], + model_config.model.rstrip("/").split("/")[-1], self.head_or_tp_rank, self.pcp_rank, self.dcp_rank, @@ -99,40 +102,28 @@ class KVPoolWorker: if self.kv_role == "kv_consumer" and self.consumer_is_to_put: num_hidden_layers = model_config.hf_text_config.num_hidden_layers partition_list_str = vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "prefill_pp_layer_partition", None) - prefill_pp_size = int( - vllm_config.kv_transfer_config.kv_connector_extra_config.get( - "prefill_pp_size", 1)) + "prefill_pp_layer_partition", None + ) + prefill_pp_size = int(vllm_config.kv_transfer_config.kv_connector_extra_config.get("prefill_pp_size", 1)) if partition_list_str is not None: try: - partitions = [ - int(layer) for layer in partition_list_str.split(",") - ] + partitions = [int(layer) for layer in partition_list_str.split(",")] except ValueError as err: - raise ValueError("Invalid partition string: {}".format( - partition_list_str)) from err + raise ValueError("Invalid partition string: {}".format(partition_list_str)) from err if len(partitions) != prefill_pp_size: - raise ValueError( - f"{len(partitions)=} does not match {prefill_pp_size=}." 
- ) + raise ValueError(f"{len(partitions)=} does not match {prefill_pp_size=}.") if sum(partitions) != num_hidden_layers: - raise ValueError( - f"{sum(partitions)=} does not match {num_hidden_layers=}." - ) + raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.") else: layers_per_partition = num_hidden_layers // prefill_pp_size - partitions = [ - layers_per_partition for _ in range(prefill_pp_size) - ] + partitions = [layers_per_partition for _ in range(prefill_pp_size)] if remaining_layers := num_hidden_layers % prefill_pp_size: for i in range(2, remaining_layers + 2): partitions[-i] += 1 - self.token_database = ChunkedTokenDatabase(self.metadata, - self.block_size, - self.use_mla, partitions) + self.token_database = ChunkedTokenDatabase(self.metadata, self.block_size, self.use_mla, partitions) real_backend = backend_map.get(self.backend.lower()) @@ -142,10 +133,11 @@ class KVPoolWorker: self.put_step = 1 self.m_store = real_backend( # type: ignore[misc] - parallel_config) + parallel_config + ) - self.kv_send_thread: Optional[KVTransferThread] = None - self.kv_recv_thread: Optional[KVTransferThread] = None + self.kv_send_thread: KVTransferThread | None = None + self.kv_recv_thread: KVTransferThread | None = None self.finished_store_req: set[str] = set() @@ -162,11 +154,14 @@ class KVPoolWorker: block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] self.block_len = [ first_kv_cache[0].element_size() * math.prod(block_shape_norm), - first_kv_cache[1].element_size() * math.prod(block_shape_pe) + first_kv_cache[1].element_size() * math.prod(block_shape_pe), ] logger.info( "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", - self.num_blocks, block_shape_norm, block_shape_pe) + self.num_blocks, + block_shape_norm, + block_shape_pe, + ) else: # [num_block, block_size, num_head, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -174,11 +169,9 @@ class KVPoolWorker: block_rank = 3 # [block_size, kv_heads, head_dim] block_shape = first_kv_cache.shape[-block_rank:] self.block_len = [kv_elem_size * math.prod(block_shape)] - logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, - block_shape) + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) - logger.info("Registering KV_Caches. use_mla: %s, shape %s", - self.use_mla, first_kv_cache.shape) + logger.info("Registering KV_Caches. 
use_mla: %s, shape %s", self.use_mla, first_kv_cache.shape) self.kv_caches = kv_caches self.kv_caches_base_addr = [] @@ -194,8 +187,7 @@ class KVPoolWorker: ptrs.append(base_addr) lengths.append(region_len) else: - cache_list = [cache_or_caches - ] if self.use_mla else cache_or_caches + cache_list = [cache_or_caches] if self.use_mla else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() self.kv_caches_base_addr.append(base_addr) @@ -208,33 +200,50 @@ class KVPoolWorker: if self.use_layerwise: self.get_event = threading.Event() - if self.kv_role in ['kv_producer', 'kv_both']: + if self.kv_role in ["kv_producer", "kv_both"]: ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreLayerSendingThread( - self.m_store, self.token_database, self.block_size, - self.tp_rank, self.dcp_size, self.put_step, - ready_event_sending, self.num_layers) + self.m_store, + self.token_database, + self.block_size, + self.tp_rank, + self.dcp_size, + self.put_step, + ready_event_sending, + self.num_layers, + ) self.kv_send_thread.start() ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreLayerRecvingThread( - self.m_store, self.token_database, self.block_size, - self.tp_rank, self.dcp_size, ready_event, self.get_event) + self.m_store, + self.token_database, + self.block_size, + self.tp_rank, + self.dcp_size, + ready_event, + self.get_event, + ) self.kv_recv_thread.start() ready_event.wait() else: - if self.kv_role in ['kv_producer', 'kv_both' - ] or self.consumer_is_to_put: + if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put: ready_event_sending = threading.Event() self.kv_send_thread = KVCacheStoreSendingThread( - self.m_store, self.token_database, self.block_size, - self.tp_rank, self.dcp_size, self.put_step, self.kv_role, - ready_event_sending) + self.m_store, + self.token_database, + self.block_size, + self.tp_rank, + self.dcp_size, + self.put_step, + self.kv_role, + ready_event_sending, + ) self.kv_send_thread.start() if self.load_async: ready_event = threading.Event() self.kv_recv_thread = KVCacheStoreRecvingThread( - self.m_store, self.token_database, self.block_size, - self.tp_rank, self.dcp_size, ready_event) + self.m_store, self.token_database, self.block_size, self.tp_rank, self.dcp_size, ready_event + ) self.kv_recv_thread.start() ready_event.wait() @@ -243,12 +252,12 @@ class KVPoolWorker: self.layerwise_retrievers = [] for request in metadata.requests: load_spec = request.load_spec - if load_spec is None or not load_spec.can_load: #load =0 + if load_spec is None or not load_spec.can_load: # load =0 continue token_len = request.token_len_chunk - if (load_spec.kvpool_cached_tokens % self.block_size - != 0) and (load_spec.kvpool_cached_tokens - == token_len - 1): + if (load_spec.kvpool_cached_tokens % self.block_size != 0) and ( + load_spec.kvpool_cached_tokens == token_len - 1 + ): token_len = request.load_spec.kvpool_cached_tokens + 1 else: token_len = request.load_spec.kvpool_cached_tokens @@ -260,30 +269,27 @@ class KVPoolWorker: else: if self.load_async: self.kv_recv_thread.add_request( # type: ignore[union-attr] - request, ) + request, + ) else: addr_list = [] size_list = [] key_list = [] - mask_num = (request.load_spec.vllm_cached_tokens // - self.block_size * self.block_size) + mask_num = request.load_spec.vllm_cached_tokens // self.block_size * self.block_size for start, end, key in self.token_database.process_tokens( - token_len, request.block_hashes, mask_num): - addr, size, _ = self.token_database.prepare_value( - 
start, end, request.block_ids) + token_len, request.block_hashes, mask_num + ): + addr, size, _ = self.token_database.prepare_value(start, end, request.block_ids) key_list.append(key.to_string()) addr_list.append(addr) size_list.append(size) - key_list_c = key_list[self.tp_rank % len( - key_list):] + key_list[:self.tp_rank % len(key_list)] - addr_list_c = addr_list[self.tp_rank % - len(addr_list - ):] + addr_list[:self.tp_rank % - len(addr_list)] - size_list_c = size_list[self.tp_rank % - len(size_list - ):] + size_list[:self.tp_rank % - len(size_list)] + key_list_c = key_list[self.tp_rank % len(key_list) :] + key_list[: self.tp_rank % len(key_list)] + addr_list_c = ( + addr_list[self.tp_rank % len(addr_list) :] + addr_list[: self.tp_rank % len(addr_list)] + ) + size_list_c = ( + size_list[self.tp_rank % len(size_list) :] + size_list[: self.tp_rank % len(size_list)] + ) self.m_store.get(key_list_c, addr_list_c, size_list_c) def wait_for_layer_load(self) -> None: @@ -294,8 +300,7 @@ class KVPoolWorker: num_retrieved_tokens = ret_token_mask.sum().item() logger.debug(f"Retrieved {num_retrieved_tokens} tokens") - def save_kv_layer(self, - connector_metadata: AscendConnectorMetadata) -> None: + def save_kv_layer(self, connector_metadata: AscendConnectorMetadata) -> None: if self.current_layer == 0: self.layerwise_storers = [] current_event = None @@ -336,15 +341,17 @@ class KVPoolWorker: continue request.current_event = current_event - self.kv_send_thread.add_stored_request( # type: ignore[union-attr] - request.req_id) + self.kv_send_thread.add_stored_request( # type: ignore[union-attr] + request.req_id + ) self.kv_send_thread.add_request( # type: ignore[union-attr] - request, ) + request, + ) def retrieve_layer( self, request: ReqMeta, - ) -> Generator[Optional[torch.Tensor], None, None]: + ) -> Generator[torch.Tensor | None, None, None]: """ Retrieve the KV cache in a layerwise manner. @@ -359,12 +366,14 @@ class KVPoolWorker: return: A generator that yields Optional[torch.Tensor]. The tensor will be the boolean mask indicating which tokens are retrieved and will - only be returned in the last iteration. + only be returned in the last iteration. 
""" token_len = request.token_len_chunk mask_num = ( request.load_spec.vllm_cached_tokens # type: ignore[union-attr] - // self.block_size * self.block_size) + // self.block_size + * self.block_size + ) num_required_tokens = token_len - mask_num ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu") @@ -373,8 +382,7 @@ class KVPoolWorker: ends = [] keys = [] first_flag = True - for start, end, key in self.token_database.process_tokens( - token_len, request.block_hashes, mask_num): + for start, end, key in self.token_database.process_tokens(token_len, request.block_hashes, mask_num): keys_multi_layer = key.split_layers(self.num_layers) starts.append(start) ends.append(end) @@ -386,16 +394,16 @@ class KVPoolWorker: keys = [list(row) for row in zip(*keys)] # [num_layer,block_num] for layer_id, keys_multi_chunk in enumerate(keys): if not first_flag: - is_finish = self.get_event.wait(timeout=3) #try---cache + is_finish = self.get_event.wait(timeout=3) # try---cache if not is_finish: logger.info("Layerwise get failed") self.get_event.clear() - req_meta = LasyerMultiBlockReqMeta(request.req_id, - keys_multi_chunk, starts, - ends, request.block_ids, - layer_id) + req_meta = LasyerMultiBlockReqMeta( + request.req_id, keys_multi_chunk, starts, ends, request.block_ids, layer_id + ) self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] - req_meta) # type: ignore[union-attr, call-arg, arg-type] + req_meta + ) # type: ignore[union-attr, call-arg, arg-type] first_flag = False yield None else: @@ -405,16 +413,14 @@ class KVPoolWorker: yield None retrieved_tokens = torch.sum(ret_mask) - logger.debug(f"Retrieved {retrieved_tokens} " - f"out of {num_required_tokens} " - f"out of total {token_len} tokens") + logger.debug(f"Retrieved {retrieved_tokens} out of {num_required_tokens} out of total {token_len} tokens") yield ret_mask def store_layer( self, request: ReqMeta, - current_event: Optional[torch.npu.Event], + current_event: torch.npu.Event | None, ) -> Generator[None, None, None]: """ Store the KV cache in a layerwise manner. 
@@ -439,69 +445,88 @@ class KVPoolWorker: starts = [] ends = [] keys = [] - for start, end, key in self.token_database.process_tokens( - request.token_len_chunk, request.block_hashes): + for start, end, key in self.token_database.process_tokens(request.token_len_chunk, request.block_hashes): keys_multi_layer = key.split_layers(self.num_layers) starts.append(start) ends.append(end) - keys.append(keys_multi_layer) #[block_num,layer_num] + keys.append(keys_multi_layer) # [block_num,layer_num] if keys: - keys = [list(row) for row in zip(*keys)] #[layer_num,block_num] + keys = [list(row) for row in zip(*keys)] # [layer_num,block_num] for layer_id, keys_multi_chunk in enumerate(keys): - req_meta = LasyerMultiBlockReqMeta(request.req_id, - keys_multi_chunk, starts, - ends, request.block_ids, - layer_id, - request.is_last_chunk, - current_event) + req_meta = LasyerMultiBlockReqMeta( + request.req_id, + keys_multi_chunk, + starts, + ends, + request.block_ids, + layer_id, + request.is_last_chunk, + current_event, + ) self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] - req_meta) # type: ignore[union-attr, call-arg, arg-type] + req_meta + ) # type: ignore[union-attr, call-arg, arg-type] yield else: for layer_id in range(self.num_layers): yield - def get_finished(self, - finished_req_ids: set[str], meta:AscendConnectorMetadata) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str], meta: AscendConnectorMetadata) -> tuple[set[str], set[str]]: done_sending = ( self.get_and_clear_finished_requests( - finished_req_ids, meta # type: ignore[union-attr] - ) if self.kv_role in ['kv_producer', 'kv_both'] - or self.consumer_is_to_put else set()) + finished_req_ids, + meta, # type: ignore[union-attr] + ) + if self.kv_role in ["kv_producer", "kv_both"] or self.consumer_is_to_put + else set() + ) done_recving = ( - self.kv_recv_thread. 
- get_and_clear_finished_requests( # type: ignore[union-attr] - ) if self.load_async else set()) + self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr] + ) + if self.load_async + else set() + ) logger.debug( - "Number of completed KV cache send requests: %d, receive " - "requests: %d, tp_rank:%d", len(done_sending), len(done_recving), - self.tp_rank) + "Number of completed KV cache send requests: %d, receive requests: %d, tp_rank:%d", + len(done_sending), + len(done_recving), + self.tp_rank, + ) return done_sending, done_recving - def get_and_clear_finished_requests(self, finished_req_ids, meta:AscendConnectorMetadata) -> set[str]: + def get_and_clear_finished_requests(self, finished_req_ids, meta: AscendConnectorMetadata) -> set[str]: finished_sending = set() for req_id in meta.preempted_req_ids: self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] - req_id) + req_id + ) for req_id in self.kv_send_thread.stored_requests.copy( # type: ignore[union-attr] ): - if self.kv_send_thread.stored_requests[ # type: ignore[union-attr] - req_id] == 0 and req_id in self.finished_store_req: + if ( + self.kv_send_thread.stored_requests[ # type: ignore[union-attr] + req_id + ] + == 0 + and req_id in self.finished_store_req + ): self.finished_store_req.remove(req_id) finished_sending.add(req_id) self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] - req_id) + req_id + ) for req_id in finished_req_ids: req_remain_jobs = self.kv_send_thread.stored_requests.get( # type: ignore[union-attr] - req_id) + req_id + ) if req_remain_jobs == 0: finished_sending.add(req_id) self.kv_send_thread.delete_finished_stored_request( # type: ignore[union-attr] - req_id) + req_id + ) elif req_remain_jobs is not None: self.finished_store_req.add(req_id) @@ -522,8 +547,7 @@ class KVPoolWorker: keys = [] try: starts = [] - for start, end, key in self.token_database.process_tokens( - token_len, block_hashes): + for start, end, key in self.token_database.process_tokens(token_len, block_hashes): if use_layerwise: keys_multi_layer = key.split_layers(self.num_layers) for item in keys_multi_layer: @@ -560,8 +584,7 @@ class KVPoolWorker: keys = [] try: starts = [] - for start, end, key in self.token_database.process_tokens( - token_len, block_hashes): + for start, end, key in self.token_database.process_tokens(token_len, block_hashes): if use_layerwise: keys_multi_layer = key.split_layers(self.num_layers) for item in keys_multi_layer: @@ -574,25 +597,25 @@ class KVPoolWorker: for i in range(1, min(self.tp_size, self.num_kv_head)): for item in keys: new_str = item.replace( # type: ignore[attr-defined] - "@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1) + "@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1 + ) multi_tp_keys.append(new_str) for i in range(1, self.pp_size): for item in keys: new_str = item.replace( # type: ignore[attr-defined] - "@pp_rank:0", f"@pp_rank:{i}", 1) + "@pp_rank:0", f"@pp_rank:{i}", 1 + ) multi_tp_keys.append(new_str) - res = self.m_store.exists( - multi_tp_keys) # type: ignore[assignment] + res = self.m_store.exists(multi_tp_keys) # type: ignore[assignment] num_block = len(keys) if use_layerwise: res = self.check_all_layers_exists(res, self.num_layers) num_block = len(keys) // self.num_layers multi_tp_values = [ - res[i * num_block:(i + 1) * num_block] # type: ignore[index] - for i in range( - min(self.tp_size, self.num_kv_head) * self.pp_size) + res[i * num_block : (i + 1) * num_block] # type: ignore[index] + for i in 
range(min(self.tp_size, self.num_kv_head) * self.pp_size) ] index = self.find_min_first_non_one_index(multi_tp_values) if index != -1: @@ -603,8 +626,7 @@ class KVPoolWorker: return start return end - def check_all_layers_exists(self, res: list[int], - num_layers: int) -> list[int]: + def check_all_layers_exists(self, res: list[int], num_layers: int) -> list[int]: total_chunks = len(res) // num_layers result = [] @@ -618,7 +640,6 @@ class KVPoolWorker: def find_min_first_non_one_index(self, arr): try: - return min(idx for row in arr for idx, val in enumerate(row) - if val != 1) + return min(idx for row in arr for idx, val in enumerate(row) if val != 1) except ValueError: return -1 diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py index 24307f5f..ecf1f81c 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_kv_cache_manager.py @@ -1,20 +1,17 @@ import time from collections import defaultdict -from typing import Optional from vllm.logger import logger from vllm.utils.hashing import sha256 from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock) -from vllm.v1.core.single_type_kv_cache_manager import \ - get_manager_for_kv_cache_spec +from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.single_type_kv_cache_manager import get_manager_for_kv_cache_spec from vllm.v1.kv_cache_interface import KVCacheSpec -from vllm.v1.metrics.stats import (PrefixCacheStats, CachingMetrics) +from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats from vllm.v1.request import Request class CPUCacheStats: - def __init__(self, enable_prefix_caching: bool, log_stats: bool = False): self.enable_prefix_caching = enable_prefix_caching self.log_stats = log_stats @@ -27,10 +24,9 @@ class CPUCacheStats: # Log the prefix cache hit rate every 10 seconds. if current_time_sec - self.time_sec >= 10: self.time_sec = current_time_sec - logger.info("CPU Prefix cache hit rate: %.1f%%", - self.cpu_prefix_cache_metrics.hit_rate * 100) + logger.info("CPU Prefix cache hit rate: %.1f%%", self.cpu_prefix_cache_metrics.hit_rate * 100) - def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: + def make_prefix_cache_stats(self) -> PrefixCacheStats | None: """Get (and reset) the prefix cache stats. Returns: The current prefix caching stats, or None if logging is disabled. @@ -57,7 +53,6 @@ class CPUCacheStats: class CPUKVCacheManager: - def __init__( self, kv_cache_spec: KVCacheSpec, @@ -70,30 +65,26 @@ class CPUKVCacheManager: self.num_cpu_blocks = num_cpu_blocks self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash self.use_eagle = use_eagle - self.block_pool = BlockPool(self.num_cpu_blocks, True, - enable_kv_cache_events) + self.block_pool = BlockPool(self.num_cpu_blocks, True, enable_kv_cache_events) self.single_type_manager = get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=0, ) # Record kv block hashes, avoid redundant computation. - self.req_to_block_hashes: defaultdict[ - str, list[BlockHash]] = defaultdict(list) + self.req_to_block_hashes: defaultdict[str, list[BlockHash]] = defaultdict(list) # Record blocks touched in get_matched_num_and_touch(). 
- self.req_to_computed_blocks: defaultdict[ - str, list[KVCacheBlock]] = defaultdict(list) + self.req_to_computed_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list) # Record the request that failed to allocate. self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool) self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int) - self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, - log_stats=True) + self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, log_stats=True) # Record request that will be free after finish sending self.req_to_free: defaultdict[str, Request] = defaultdict(Request) def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]: # When the request requires prompt logprobs, we skip prefix caching. - if (request.sampling_params.prompt_logprobs is not None): + if request.sampling_params.prompt_logprobs is not None: return 0, False request_id = request.request_id # The block hashes for the request may already be computed @@ -119,10 +110,8 @@ class CPUKVCacheManager: # cup prefix cache status set and log assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None - self.cpu_cache_stats.set_cache_stats(request.num_tokens, - num_computed_tokens) - self.cpu_cache_stats.cpu_prefix_cache_metrics.observe( - self.cpu_cache_stats.prefix_cache_stats) + self.cpu_cache_stats.set_cache_stats(request.num_tokens, num_computed_tokens) + self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(self.cpu_cache_stats.prefix_cache_stats) self.cpu_cache_stats.log() return num_computed_tokens, False @@ -130,12 +119,10 @@ class CPUKVCacheManager: def _release_ahead_touch(self, request_id: str): computed_blocks = self.req_to_computed_blocks[request_id] if computed_blocks: - self.single_type_manager.block_pool.free_blocks( - reversed(computed_blocks)) + self.single_type_manager.block_pool.free_blocks(reversed(computed_blocks)) self.req_to_computed_blocks.pop(request_id, None) - def allocate_slots(self, req_to_num_tokens: dict[str, int], - unallocated_req_ids: set[str]) -> dict[str, list[int]]: + def allocate_slots(self, req_to_num_tokens: dict[str, int], unallocated_req_ids: set[str]) -> dict[str, list[int]]: for request_id in unallocated_req_ids: self._free_slots(request_id) req_to_new_blocks = {} @@ -143,44 +130,34 @@ class CPUKVCacheManager: if self.req_failed_to_allocate[request_id]: continue new_computed_blocks = self.req_to_computed_blocks[request_id] - num_blocks_to_allocate = ( - self.single_type_manager.get_num_blocks_to_allocate( - request_id=request_id, - num_tokens=num_tokens, - new_computed_blocks=new_computed_blocks, - )) + num_blocks_to_allocate = self.single_type_manager.get_num_blocks_to_allocate( + request_id=request_id, + num_tokens=num_tokens, + new_computed_blocks=new_computed_blocks, + ) if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): self._release_ahead_touch(request_id) self.req_failed_to_allocate[request_id] = True continue # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - self.single_type_manager.save_new_computed_blocks( - request_id, new_computed_blocks) + self.single_type_manager.save_new_computed_blocks(request_id, new_computed_blocks) # Allocate new blocks but do not cache now. 
- new_blocks = self.single_type_manager.allocate_new_blocks( - request_id, num_tokens) + new_blocks = self.single_type_manager.allocate_new_blocks(request_id, num_tokens) self.req_to_num_tokens[request_id] = num_tokens # No need to release ref_cnt because we use officially. self.req_to_computed_blocks.pop(request_id, None) - req_to_new_blocks[request_id] = [ - block.block_id for block in new_computed_blocks + new_blocks - ] + req_to_new_blocks[request_id] = [block.block_id for block in new_computed_blocks + new_blocks] return req_to_new_blocks def record_request_cache_and_free_slots(self, request: Request): - logger.debug( - f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager" - ) + logger.debug(f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager") self.req_to_free[request.request_id] = request def cache_and_free_slots(self, request_id: str): - logger.debug( - f"Cache and free slots for request {request_id} in cpu_kv_cache_manager" - ) + logger.debug(f"Cache and free slots for request {request_id} in cpu_kv_cache_manager") if request_id not in self.req_to_free: - logger.Error( - f"request {request_id} not in req_to_free, maybe bug!") + logger.error(f"request {request_id} not in req_to_free, maybe bug!") return request = self.req_to_free[request_id] if not self.req_failed_to_allocate[request_id]: @@ -189,8 +166,7 @@ class CPUKVCacheManager: self.req_to_num_tokens[request_id], ) self._free_slots(request_id) - logger.debug( - f"delete request {request_id} in cpu_kv_cache_manager req_to_free") + logger.debug(f"delete request {request_id} in cpu_kv_cache_manager req_to_free") del self.req_to_free[request_id] def _free_slots(self, request_id: str): diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py index 7e771b3a..c9d2cc1d 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/cpu_offload_connector.py @@ -5,15 +5,15 @@ import queue import threading import time from collections import defaultdict +from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional import torch from vllm.attention.layer import Attention, MLAAttention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -23,12 +23,14 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import ( - MetadataServer, MetadataServerProc, MLAConfig) - + MetadataServer, + MetadataServerProc, + MLAConfig, +) if TYPE_CHECKING: - from vllm.v1.attention.backend import AttentionMetadata #type: ignore from vllm.forward_context import ForwardContext + from vllm.v1.attention.backend
import AttentionMetadata # type: ignore from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request @@ -59,20 +61,15 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata): class CPUOffloadingConnector(KVConnectorBase_V1): - - def __init__(self, - vllm_config: VllmConfig, - role: KVConnectorRole, - kv_cache_config: Optional["KVCacheConfig"] = None): + def __init__( + self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None + ): self._connector_metadata = CPUOffloadingConnectorMetadata(requests={}, finished_req_ids=set()) if not vllm_config.cache_config.enable_prefix_caching: - self.connector_scheduler: Optional[ - CPUOffloadingConnectorScheduler] = None - self.connector_worker: Optional[ - CPUOffloadingConnectorWorker] = None + self.connector_scheduler: CPUOffloadingConnectorScheduler | None = None + self.connector_worker: CPUOffloadingConnectorWorker | None = None elif role == KVConnectorRole.SCHEDULER: - self.connector_scheduler = CPUOffloadingConnectorScheduler( - vllm_config) + self.connector_scheduler = CPUOffloadingConnectorScheduler(vllm_config) self.connector_worker = None elif role == KVConnectorRole.WORKER: self.connector_scheduler = None @@ -82,11 +79,9 @@ class CPUOffloadingConnector(KVConnectorBase_V1): # Worker-side methods # ============================== - def bind_connector_metadata( - self, connector_metadata: KVConnectorMetadata) -> None: + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: if self.connector_worker is not None: - assert isinstance(connector_metadata, - CPUOffloadingConnectorMetadata) + assert isinstance(connector_metadata, CPUOffloadingConnectorMetadata) self.connector_worker.bind_connector_metadata(connector_metadata) def clear_connector_metadata(self) -> None: @@ -97,8 +92,7 @@ class CPUOffloadingConnector(KVConnectorBase_V1): if self.connector_worker is not None: self.connector_worker.register_kv_caches(kv_caches) - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: if self.connector_worker is not None: self.connector_worker.start_load_kv() @@ -106,53 +100,42 @@ class CPUOffloadingConnector(KVConnectorBase_V1): if self.connector_worker is not None: self.connector_worker.wait_for_layer_load() - def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, - attn_metadata: "AttentionMetadata", **kwargs) -> None: + def save_kv_layer( + self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs + ) -> None: pass def wait_for_save(self): pass - def get_finished( - self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str] | None, set[str] | None]: assert self.connector_worker is not None return self.connector_worker.get_finished(), None # Scheduler-side methods # ============================== - def get_num_new_matched_tokens( - self, request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int, bool]: if self.connector_scheduler is not None: - return self.connector_scheduler.get_num_new_matched_tokens( - request, num_computed_tokens) + return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens) return 0, False - def 
update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int): + def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): if self.connector_scheduler is not None: return self.connector_scheduler.update_state_after_alloc(request) - def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: if self.connector_scheduler is not None: - return self.connector_scheduler.build_connector_meta( - scheduler_output) + return self.connector_scheduler.build_connector_meta(scheduler_output) return KVConnectorMetadata() - def request_finished( - self, request: "Request", - block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]: + def request_finished(self, request: "Request", block_ids: list[int]) -> tuple[bool, dict[str, Any] | None]: if self.connector_scheduler is not None: self.connector_scheduler.request_finished(request) return True, None class CPUOffloadingConnectorScheduler: - def __init__(self, vllm_config: VllmConfig): logger.info("init CPUOffloadingConnectorScheduler") self.vllm_config = vllm_config @@ -165,22 +148,17 @@ class CPUOffloadingConnectorScheduler: self.zmq_rpc_client = MetadataServer.ZMQRPCClient() self.zmq_rpc_client.call("post_init") if vllm_config.kv_transfer_config is not None: - self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config( - "swap_in_threshold", 0) + self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config("swap_in_threshold", 0) else: self.swap_in_threshold = 0 logger.info(f"swap_in_threshold: {self.swap_in_threshold}") - def get_num_new_matched_tokens( - self, ori_request: "Request", - num_computed_tokens: int) -> tuple[int, bool]: + def get_num_new_matched_tokens(self, ori_request: "Request", num_computed_tokens: int) -> tuple[int, bool]: request = copy.deepcopy(ori_request) request.get_hash_new_full_blocks = None - num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call( - "get_matched_num_and_touch", request) + num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call("get_matched_num_and_touch", request) self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens - self.num_cpu_computed_tokens[ - request.request_id] = num_cpu_computed_tokens + self.num_cpu_computed_tokens[request.request_id] = num_cpu_computed_tokens if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold: return num_cpu_computed_tokens - num_computed_tokens, load_async else: @@ -189,29 +167,22 @@ class CPUOffloadingConnectorScheduler: def update_state_after_alloc(self, request: "Request"): self.allocated_req_ids.add(request.request_id) - def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: num_tokens = {} # process scheduled_new_reqs for req in scheduler_output.scheduled_new_reqs: req_id = req.req_id - num_tokens[req_id] = ( - req.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) + num_tokens[req_id] = req.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id] # process scheduled_cached_reqs cached_reqs = scheduler_output.scheduled_cached_reqs for idx, req_id in enumerate(cached_reqs.req_ids): - num_tokens[req_id] = ( - cached_reqs.num_computed_tokens[idx] + - scheduler_output.num_scheduled_tokens[req_id]) + 
num_tokens[req_id] = cached_reqs.num_computed_tokens[idx] + scheduler_output.num_scheduled_tokens[req_id] - unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() - - self.allocated_req_ids - - scheduler_output.num_scheduled_tokens.keys()) - new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", - num_tokens, - unallocated_req_ids) + unallocated_req_ids = set( + self.num_gpu_computed_tokens.keys() - self.allocated_req_ids - scheduler_output.num_scheduled_tokens.keys() + ) + new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots", num_tokens, unallocated_req_ids) metadata = CPUOffloadingConnectorMetadata( requests={}, finished_req_ids=set(self.finished_req_ids), @@ -222,22 +193,22 @@ class CPUOffloadingConnectorScheduler: metadata.requests[req_id] = ReqMeta( gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids, cpu_block_ids=new_cpu_block_ids.get(req_id, []), - num_scheduled_tokens=scheduler_output. - num_scheduled_tokens[req_id], + num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id], num_computed_tokens=req.num_computed_tokens, num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id], - num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id]) + num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id], + ) for idx, req_id in enumerate(cached_reqs.req_ids): gpu_block_ids = cached_reqs.new_block_ids[idx] metadata.requests[req_id] = ReqMeta( gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids, cpu_block_ids=new_cpu_block_ids.get(req_id, []), - num_scheduled_tokens=scheduler_output. - num_scheduled_tokens[req_id], + num_scheduled_tokens=scheduler_output.num_scheduled_tokens[req_id], num_computed_tokens=cached_reqs.num_computed_tokens[idx], num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx], - num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx]) + num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx], + ) self.num_gpu_computed_tokens.clear() self.num_cpu_computed_tokens.clear() self.allocated_req_ids.clear() @@ -249,12 +220,10 @@ class CPUOffloadingConnectorScheduler: request.get_hash_new_full_blocks = None self.finished_req_ids.append(request.request_id) # inform metadata server to record request, and free it after finish sending - self.zmq_rpc_client.call("record_request_cache_and_free_slots", - request) + self.zmq_rpc_client.call("record_request_cache_and_free_slots", request) class CPUOffloadingConnectorWorker: - def __init__(self, vllm_config: VllmConfig): logger.info("init CPUOffloadingConnectorWorker") self.vllm_config = vllm_config @@ -289,7 +258,7 @@ class CPUOffloadingConnectorWorker: def init_metadata_server(self, vllm_config: VllmConfig): self.metadata_thread = threading.Thread( target=MetadataServerProc.run_metadata_server, - args=(vllm_config, ), + args=(vllm_config,), ) self.metadata_thread.daemon = True self.metadata_thread.start() @@ -304,18 +273,15 @@ class CPUOffloadingConnectorWorker: logger.info(f"wait for metadata server to start, error: {e}") time.sleep(1) - def bind_connector_metadata( - self, connector_metadata: CPUOffloadingConnectorMetadata) -> None: + def bind_connector_metadata(self, connector_metadata: CPUOffloadingConnectorMetadata) -> None: for req_id, req in connector_metadata.requests.items(): if req_id in self.requests: self.requests[req_id].update(req) req = self.requests[req_id] else: self.requests[req_id] = req - for i in range(req.num_gpu_computed_tokens // self.block_size, - req.num_computed_tokens // self.block_size): - self.load_block_mapping.append( - 
(req.cpu_block_ids[i], req.gpu_block_ids[i])) + for i in range(req.num_gpu_computed_tokens // self.block_size, req.num_computed_tokens // self.block_size): + self.load_block_mapping.append((req.cpu_block_ids[i], req.gpu_block_ids[i])) for req_id in connector_metadata.finished_req_ids: if req_id in self.requests: self.save_input_queue.put((req_id, self.requests[req_id])) @@ -326,11 +292,11 @@ class CPUOffloadingConnectorWorker: def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]): self.gpu_kv_caches = kv_caches model_config = self.vllm_config.model_config - mla_config: Optional[MLAConfig] = None + mla_config: MLAConfig | None = None if model_config.use_mla: mla_config = MLAConfig( - model_config.hf_text_config.kv_lora_rank, - model_config.hf_text_config.qk_rope_head_dim) + model_config.hf_text_config.kv_lora_rank, model_config.hf_text_config.qk_rope_head_dim + ) self.cpu_kv_caches = list( self.zmq_rpc_client.call( "init_cpu_kv_caches", @@ -338,7 +304,8 @@ class CPUOffloadingConnectorWorker: self.tp_rank, get_kv_cache_spec(self.vllm_config), mla_config, - ).values()) + ).values() + ) def start_load_kv(self) -> None: self.current_layer = 0 @@ -358,10 +325,8 @@ class CPUOffloadingConnectorWorker: cpu_kv_caches = self.cpu_kv_caches[layer] with torch.npu.stream(self.load_stream): for cpu_block_id, gpu_block_id in self.load_block_mapping: - for gpu_layer_part, cpu_layer_part in zip( - gpu_kv_caches, cpu_kv_caches): - gpu_layer_part[gpu_block_id].copy_( - cpu_layer_part[cpu_block_id], non_blocking=True) + for gpu_layer_part, cpu_layer_part in zip(gpu_kv_caches, cpu_kv_caches): + gpu_layer_part[gpu_block_id].copy_(cpu_layer_part[cpu_block_id], non_blocking=True) def get_finished(self) -> set[str]: done_sending: set[str] = set() @@ -380,8 +345,7 @@ class CPUOffloadingConnectorWorker: self.done_sending_count[req_id] += 1 other_ranks_finished_ids: list[str] = [] for i in range(1, self.tp_world_size): - other_ranks_finished_ids.extend( - self.tp_group.recv_object(src=i)) + other_ranks_finished_ids.extend(self.tp_group.recv_object(src=i)) for req_id in other_ranks_finished_ids: self.done_sending_count[req_id] += 1 all_done_sending: set[str] = set() @@ -391,8 +355,7 @@ class CPUOffloadingConnectorWorker: all_done_sending.add(req_id) # release cpu_kv_cache after request sending finished # to avoid rpc blocking, use thread to call rpc asynchronously - sending_finished_thread = threading.Thread( - target=self._sending_finished, args=(all_done_sending, )) + sending_finished_thread = threading.Thread(target=self._sending_finished, args=(all_done_sending,)) sending_finished_thread.daemon = True sending_finished_thread.start() @@ -411,11 +374,10 @@ class CPUOffloadingConnectorWorker: while True: req_id, req = self.save_input_queue.get() for i in range( - req.num_cpu_computed_tokens // self.block_size, - min((req.num_computed_tokens + req.num_scheduled_tokens) // - self.block_size, len(req.cpu_block_ids))): - save_block_mapping.append( - (req.gpu_block_ids[i], req.cpu_block_ids[i])) + req.num_cpu_computed_tokens // self.block_size, + min((req.num_computed_tokens + req.num_scheduled_tokens) // self.block_size, len(req.cpu_block_ids)), + ): + save_block_mapping.append((req.gpu_block_ids[i], req.cpu_block_ids[i])) with torch.npu.stream(self.save_stream): # MLA: kv_layer is tuple[tensor, tensor] means (rope, nope). # non-MLA: kv_layer is list[tensor], typically means [k, v]. 
@@ -425,13 +387,9 @@ class CPUOffloadingConnectorWorker: start, step = 0, 1 for i in range(start, len(save_block_mapping), step): gpu_block_id, cpu_block_id = save_block_mapping[i] - for cpu_kv_caches, gpu_kv_caches in zip( - self.cpu_kv_caches, self.gpu_kv_caches.values()): - for cpu_layer_part, gpu_layer_part in zip( - cpu_kv_caches, gpu_kv_caches): - cpu_layer_part[cpu_block_id].copy_( - gpu_layer_part[gpu_block_id], - non_blocking=True) + for cpu_kv_caches, gpu_kv_caches in zip(self.cpu_kv_caches, self.gpu_kv_caches.values()): + for cpu_layer_part, gpu_layer_part in zip(cpu_kv_caches, gpu_kv_caches): + cpu_layer_part[cpu_block_id].copy_(gpu_layer_part[gpu_block_id], non_blocking=True) self.save_stream.synchronize() self.save_output_queue.put(req_id) save_block_mapping.clear() @@ -453,8 +411,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: if vllm_config.cache_config.cache_dtype == "auto": kv_cache_dtype = vllm_config.model_config.dtype else: - kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - vllm_config.cache_config.cache_dtype] + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[vllm_config.cache_config.cache_dtype] kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase) @@ -472,10 +429,8 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: # using DSA. Fix the spec in vLLM is the final way. block_size = vllm_config.cache_config.block_size kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=1, - head_size=attn_module.head_size, - dtype=kv_cache_dtype) + block_size=block_size, num_kv_heads=1, head_size=attn_module.head_size, dtype=kv_cache_dtype + ) elif spec := attn_module.get_kv_cache_spec(vllm_config): kv_cache_spec[layer_name] = spec @@ -484,8 +439,7 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]: if len(mamba_layers) > 0: if vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") + raise NotImplementedError("Prefix caching is not supported for Mamba yet.") for layer_name, mamba_module in mamba_layers.items(): if spec := mamba_module.get_kv_cache_spec(vllm_config): kv_cache_spec[layer_name] = spec diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py index ab5bc08c..a266c0ff 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/cpu_offload/metadata.py @@ -1,9 +1,10 @@ import math import os import pickle +from collections.abc import Callable from dataclasses import dataclass from multiprocessing.shared_memory import SharedMemory -from typing import Any, Callable, Optional +from typing import Any import torch import vllm.envs as envs @@ -14,8 +15,7 @@ from vllm.utils.network_utils import make_zmq_socket from vllm.utils.torch_utils import get_dtype_size from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec -from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import \ - CPUKVCacheManager +from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.cpu_kv_cache_manager import CPUKVCacheManager @dataclass @@ -30,8 +30,7 @@ def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig: if kv_transfer_config.kv_connector == "CPUOffloadingConnector": return kv_transfer_config elif kv_transfer_config.kv_connector == 
"MultiConnector": - ktcs = kv_transfer_config.kv_connector_extra_config.get( - "connectors") + ktcs = kv_transfer_config.kv_connector_extra_config.get("connectors") for ktc in ktcs: kv_transfer_config = KVTransferConfig(**ktc) if kv_transfer_config.kv_connector == "CPUOffloadingConnector": @@ -44,7 +43,6 @@ class MetadataServer: DEFAULT_CPU_SWAP_SPACE_GB = 800 class ZMQRPCClient: - def __init__(self, identity=None): if identity is None: identity = f"worker-{os.getpid()}-{id(self)}" @@ -56,7 +54,8 @@ class MetadataServer: zmq.DEALER, # type: ignore bind=False, identity=identity.encode(), - linger=0) + linger=0, + ) def call(self, func_name: str, *args, **kwargs) -> Any: request = (func_name, args, kwargs) @@ -74,11 +73,9 @@ class MetadataServer: self.shared_memory_dict = memory_dict result = {} for key, shm in memory_dict.items(): - tensor = torch.frombuffer( - shm.buf, dtype=layer_dtype).reshape(layer_size) + tensor = torch.frombuffer(shm.buf, dtype=layer_dtype).reshape(layer_size) if mla_config is not None: - tensor = tensor.split( - [mla_config.nope_dim, mla_config.rope_dim], dim=-1) + tensor = tensor.split([mla_config.nope_dim, mla_config.rope_dim], dim=-1) result[key] = tensor return result @@ -86,7 +83,7 @@ class MetadataServer: # will be finalized by outer process self.socket.close() self.ctx.term() - if hasattr(self, 'shared_memory_dict'): + if hasattr(self, "shared_memory_dict"): for shm in self.shared_memory_dict.values(): shm.close() @@ -96,7 +93,8 @@ class MetadataServer: kv_transfer_config = get_cpu_offload_connector(vllm_config) assert kv_transfer_config is not None available_memory_gb = kv_transfer_config.get_from_extra_config( - "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB) + "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB + ) self.available_memory = available_memory_gb * 1024 * 1024 * 1024 logger.info(f"cpu swap space: {self.available_memory} bytes") self.ctx = zmq.Context() # type: ignore @@ -105,7 +103,8 @@ class MetadataServer: MetadataServer.METADATA_SERVER_ADDRESS, zmq.ROUTER, # type: ignore bind=True, - linger=0) + linger=0, + ) self.functions: dict[str, Callable] = { "init_cpu_kv_caches": self.init_cpu_kv_caches, "post_init": self.post_init, @@ -133,15 +132,11 @@ class MetadataServer: tp_rank: int, kv_cache_specs: dict[str, AttentionSpec], mla_config: MLAConfig, - ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, - MLAConfig]: + ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype, MLAConfig]: logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}") # follow the assumption that each layer has the same spec layer = next(iter(kv_cache_specs.values())) - assert all([ - layer.page_size_bytes == any.page_size_bytes - for any in kv_cache_specs.values() - ]) + assert all([layer.page_size_bytes == any.page_size_bytes for any in kv_cache_specs.values()]) use_mla = isinstance(layer, MLAAttentionSpec) # mla shares the same kv cache among different tp if use_mla: @@ -154,30 +149,24 @@ class MetadataServer: available_memory //= self.pipeline_parallel_size available_memory //= len(kv_cache_specs) num_blocks = available_memory // layer.page_size_bytes - layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, - layer.head_size) # type: ignore + layer_size = (num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore else: available_memory //= self.world_size available_memory //= len(kv_cache_specs) num_blocks = available_memory // layer.page_size_bytes - layer_size = (2, num_blocks, 
layer.block_size, layer.num_kv_heads, - layer.head_size) # type: ignore + layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads, layer.head_size) # type: ignore nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype) - for layer_name in kv_cache_specs.keys(): + for layer_name in kv_cache_specs: # only this format can share during ZeroMQ+pickle - shared_memory_dict[ - layer_name] = MetadataServer._safe_create_shared_memory( - f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes) + shared_memory_dict[layer_name] = MetadataServer._safe_create_shared_memory( + f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes + ) if use_mla: assert mla_config is not None assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim - self.shared_memory[(pp_rank, - tp_rank)] = (shared_memory_dict, layer_size, - layer.dtype, mla_config) + self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, mla_config) else: - self.shared_memory[(pp_rank, - tp_rank)] = (shared_memory_dict, layer_size, - layer.dtype, None) + self.shared_memory[(pp_rank, tp_rank)] = (shared_memory_dict, layer_size, layer.dtype, None) if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks: self.num_cpu_blocks = num_blocks self.layer = layer @@ -185,23 +174,20 @@ class MetadataServer: def post_init(self): # different processors in data parallel may call multiple times - if hasattr(self, 'cpu_block_manager'): + if hasattr(self, "cpu_block_manager"): return # do shared_memory() at least once logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}") assert self.num_cpu_blocks >= 0 - self.cpu_block_manager = CPUKVCacheManager(self.layer, - self.num_cpu_blocks) - self.functions.update({ - "get_matched_num_and_touch": - self.cpu_block_manager.get_matched_num_and_touch, - "allocate_slots": - self.cpu_block_manager.allocate_slots, - "record_request_cache_and_free_slots": - self.cpu_block_manager.record_request_cache_and_free_slots, - "cache_and_free_slots": - self.cpu_block_manager.cache_and_free_slots, - }) + self.cpu_block_manager = CPUKVCacheManager(self.layer, self.num_cpu_blocks) + self.functions.update( + { + "get_matched_num_and_touch": self.cpu_block_manager.get_matched_num_and_touch, + "allocate_slots": self.cpu_block_manager.allocate_slots, + "record_request_cache_and_free_slots": self.cpu_block_manager.record_request_cache_and_free_slots, + "cache_and_free_slots": self.cpu_block_manager.cache_and_free_slots, + } + ) def serve_step(self): client_id = self.socket.recv() @@ -228,8 +214,7 @@ class MetadataServer: def shutdown(self): self.socket.close() self.ctx.term() - socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace( - "ipc://", "") + socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace("ipc://", "") if os.path.exists(socket_path): os.remove(socket_path) for cached in self.shared_memory.values(): @@ -239,11 +224,9 @@ class MetadataServer: class MetadataServerProc: - @staticmethod def run_metadata_server(vllm_config: VllmConfig): - if (not vllm_config.cache_config.enable_prefix_caching - or get_cpu_offload_connector(vllm_config) is None): + if not vllm_config.cache_config.enable_prefix_caching or get_cpu_offload_connector(vllm_config) is None: return shutdown_requested = False @@ -257,7 +240,7 @@ class MetadataServerProc: # Either SIGTERM or SIGINT will terminate the worker # signal.signal(signal.SIGTERM, _signal_handler) # signal.signal(signal.SIGINT, _signal_handler) - metadata_server: Optional[MetadataServer] = None + metadata_server: 
MetadataServer | None = None try: metadata_server = MetadataServer(vllm_config) logger.info("Metadata server started.") diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py index a4c62681..550d98ac 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ucm_connector.py @@ -4,19 +4,21 @@ from typing import TYPE_CHECKING, Any, Optional import torch from ucm.integration.vllm.ucm_connector import UCMConnector from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput - logger = init_logger(__name__) # isort: off if TYPE_CHECKING: from vllm.v1.attention.backend import AttentionMetadata # type: ignore from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorPromMetrics, KVConnectorStats, PromMetric, PromMetricT) + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, + ) from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.kv_cache_interface import KVCacheConfig @@ -25,16 +27,13 @@ if TYPE_CHECKING: class UCMConnectorV1(KVConnectorBase_V1): - def __init__( self, vllm_config: "VllmConfig", role: KVConnectorRole, kv_cache_config: "KVCacheConfig", ): - super().__init__(vllm_config=vllm_config, - role=role, - kv_cache_config=kv_cache_config) + super().__init__(vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config) assert vllm_config.kv_transfer_config is not None ImplCls = UCMConnector @@ -60,8 +59,7 @@ class UCMConnectorV1(KVConnectorBase_V1): """ self._ucm_engine.register_kv_caches(kv_caches) - def start_load_kv(self, forward_context: "ForwardContext", - **kwargs: Any) -> None: + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged KV buffer. This is called from the forward context before the @@ -110,8 +108,7 @@ class UCMConnectorV1(KVConnectorBase_V1): attn_metadata (AttentionMetadata): the attention metadata. **kwargs: additional arguments for the save operation. """ - self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, - **kwargs) + self._ucm_engine.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs) def wait_for_save(self) -> None: """ @@ -131,8 +128,7 @@ class UCMConnectorV1(KVConnectorBase_V1): """ self._ucm_engine.clear_connector_metadata() - def bind_connector_metadata( - self, connector_metadata: KVConnectorMetadata) -> None: + def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: """Set the connector metadata from the scheduler. This function should be called by the model runner every time @@ -175,20 +171,15 @@ class UCMConnectorV1(KVConnectorBase_V1): the number of tokens that can be loaded from the external KV cache beyond what is already computed. 
""" - return self._ucm_engine.get_num_new_matched_tokens( - request, num_computed_tokens) + return self._ucm_engine.get_num_new_matched_tokens(request, num_computed_tokens) - def update_state_after_alloc(self, request: "Request", - blocks: "KVCacheBlocks", - num_external_tokens: int) -> None: + def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int) -> None: """ Update KVConnector state after block allocation. """ - self._ucm_engine.update_state_after_alloc(request, blocks, - num_external_tokens) + self._ucm_engine.update_state_after_alloc(request, blocks, num_external_tokens) - def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: """ Build the connector metadata for this step. @@ -222,10 +213,7 @@ class UCMConnectorV1(KVConnectorBase_V1): # ============================== @classmethod - def build_kv_connector_stats( - cls, - data: dict[str, Any] | None = None - ) -> Optional["KVConnectorStats"]: + def build_kv_connector_stats(cls, data: dict[str, Any] | None = None) -> Optional["KVConnectorStats"]: """ KVConnectorStats resolution method. This method allows dynamically registered connectors to return their own KVConnectorStats object, diff --git a/vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py b/vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py index d1423c28..3527f9cd 100644 --- a/vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py +++ b/vllm_ascend/distributed/kv_transfer/utils/mooncake_transfer_engine.py @@ -1,19 +1,16 @@ -import ipaddress import threading -from typing import Optional from mooncake.engine import TransferEngine # type: ignore -class GlobalTE(): - +class GlobalTE: def __init__(self): self.transfer_engine = None self.is_register_buffer: bool = False self.transfer_engine_lock = threading.Lock() self.register_buffer_lock = threading.Lock() - def get_transfer_engine(self, hostname: str, device_name: Optional[str]): + def get_transfer_engine(self, hostname: str, device_name: str | None): if self.transfer_engine is None: with self.transfer_engine_lock: # Double-Checked Locking @@ -22,12 +19,9 @@ class GlobalTE(): raise RuntimeError("mooncake is not available") self.transfer_engine = TransferEngine() device_name = device_name if device_name is not None else "" - ret_value = self.transfer_engine.initialize( - hostname, "P2PHANDSHAKE", "ascend", device_name) + ret_value = self.transfer_engine.initialize(hostname, "P2PHANDSHAKE", "ascend", device_name) if ret_value != 0: - raise RuntimeError( - f"TransferEngine initialization failed with ret_value: {ret_value}" - ) + raise RuntimeError(f"TransferEngine initialization failed with ret_value: {ret_value}") return self.transfer_engine def register_buffer(self, ptrs: list[int], sizes: list[int]): diff --git a/vllm_ascend/distributed/kv_transfer/utils/utils.py b/vllm_ascend/distributed/kv_transfer/utils/utils.py index c25c1f15..4886f4ac 100644 --- a/vllm_ascend/distributed/kv_transfer/utils/utils.py +++ b/vllm_ascend/distributed/kv_transfer/utils/utils.py @@ -6,8 +6,7 @@ import torch.distributed as dist from vllm_ascend.distributed.parallel_state import get_p_tp_group -def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, - value: torch.TensorType): +def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, value: torch.TensorType): if pd_tp_ratio <= 1: return None, None elif 
key is None or value is None: @@ -20,22 +19,17 @@ def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor, def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor): num_kv_heads = input_tensor.size(1) output_tensor = torch.zeros_like(input_tensor) - dist.all_to_all_single(output_tensor, - input_tensor, - group=get_p_tp_group().device_group) + dist.all_to_all_single(output_tensor, input_tensor, group=get_p_tp_group().device_group) input_tensor = 0 result = rearrange_output(output_tensor, tp_ratio, num_kv_heads) output_tensor = 0 return result -def rearrange_output(base_output: torch.Tensor, cut_num: int, - num_kv_heads: int): +def rearrange_output(base_output: torch.Tensor, cut_num: int, num_kv_heads: int): size_0 = base_output.size(0) if size_0 % cut_num != 0: - raise ValueError( - f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]" - ) + raise ValueError(f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]") chunk_size = size_0 // cut_num reshaped = base_output.view(cut_num, chunk_size, -1) transposed = reshaped.transpose(0, 1) @@ -46,16 +40,13 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor: data_ptr = tensor.data_ptr() aligned_addr = (data_ptr + alignment - 1) // alignment * alignment offset = (aligned_addr - data_ptr) // tensor.element_size() - return tensor[int(offset):] + return tensor[int(offset) :] def get_transfer_timeout_value(): ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "") if len(ascend_transfer_timeout) > 0: return int(ascend_transfer_timeout) - hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT', - '20')) # type: ignore - hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT', - '7')) # type: ignore - return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + - 3000) + hccl_rdma_timeout = int(os.getenv("HCCL_RDMA_TIMEOUT", "20")) # type: ignore + hccl_rdma_retry_cnt = int(os.getenv("HCCL_RDMA_RETRY_CNT", "7")) # type: ignore + return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000) diff --git a/vllm_ascend/kv_offload/cpu_npu.py b/vllm_ascend/kv_offload/cpu_npu.py index 13e9869d..76b1926f 100644 --- a/vllm_ascend/kv_offload/cpu_npu.py +++ b/vllm_ascend/kv_offload/cpu_npu.py @@ -4,8 +4,7 @@ from vllm.logger import init_logger from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backend import AttentionBackend # type: ignore from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec -from vllm.v1.kv_offload.worker.worker import (OffloadingHandler, - TransferResult, TransferSpec) +from vllm.v1.kv_offload.worker.worker import OffloadingHandler, TransferResult, TransferSpec logger = init_logger(__name__) @@ -44,7 +43,6 @@ def expand_block_ids( class CpuNpuOffloadingHandler(OffloadingHandler): - def __init__( self, gpu_block_size: int, @@ -81,20 +79,22 @@ class CpuNpuOffloadingHandler(OffloadingHandler): cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor logger.debug("Allocating CPU tensor of shape %r", cpu_shape) - self.cpu_tensors.append(( - torch.zeros( - cpu_shape, - dtype=gpu_tensor[0].dtype, - device="cpu", - pin_memory=pin_memory, - ), - torch.zeros( - cpu_shape, - dtype=gpu_tensor[0].dtype, - device="cpu", - pin_memory=pin_memory, - ), - )) + self.cpu_tensors.append( + ( + torch.zeros( + cpu_shape, + dtype=gpu_tensor[0].dtype, + device="cpu", + pin_memory=pin_memory, + ), + torch.zeros( + cpu_shape, + dtype=gpu_tensor[0].dtype, + device="cpu", + 
pin_memory=pin_memory, + ), + ) + ) def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: logger.info("start transfer_async...") @@ -123,9 +123,7 @@ class CpuNpuOffloadingHandler(OffloadingHandler): dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor src_sub_block_count = src_blocks.size * src_block_size_factor - assert ( - src_sub_block_count == dst_blocks.size * dst_block_size_factor - - dst_sub_blocks_to_skip) + assert src_sub_block_count == dst_blocks.size * dst_block_size_factor - dst_sub_blocks_to_skip src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64) expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0]) @@ -137,18 +135,14 @@ class CpuNpuOffloadingHandler(OffloadingHandler): ) src_to_dst_tensor = torch.from_numpy(src_to_dst) - event = self.events_pool.pop( - ) if self.events_pool else torch.npu.Event() + event = self.events_pool.pop() if self.events_pool else torch.npu.Event() with torch.npu.stream(stream): for src_tensor, dst_tensor in zip(src_tensors, dst_tensors): src_key_cache, src_value_cache = src_tensor[0], src_tensor[1] dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1] - torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache, - src_to_dst_tensor) - torch.ops._C_ascend.swap_blocks(src_value_cache, - dst_value_cache, - src_to_dst_tensor) + torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) + torch.ops._C_ascend.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) event.record(stream) @@ -175,4 +169,4 @@ class CpuNpuOffloadingHandler(OffloadingHandler): event = self.transfer_events.get(job_id) if event is not None: # This will block until the NPU event is complete - event.synchronize() \ No newline at end of file + event.synchronize() diff --git a/vllm_ascend/kv_offload/npu.py b/vllm_ascend/kv_offload/npu.py index 3c5d8ae0..211f3dce 100644 --- a/vllm_ascend/kv_offload/npu.py +++ b/vllm_ascend/kv_offload/npu.py @@ -1,48 +1,40 @@ from collections.abc import Iterator -from typing import Optional import torch from vllm.config import VllmConfig from vllm.v1.attention.backend import AttentionBackend # type: ignore +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager from vllm.v1.kv_offload.backends.cpu import CPUBackend from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.spec import OffloadingSpec from vllm.v1.kv_offload.worker.worker import OffloadingHandler -from vllm.v1.kv_cache_interface import KVCacheConfig from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler class NPUOffloadingSpec(OffloadingSpec): - - def __init__(self, - vllm_config: VllmConfig, - kv_cache_config: Optional[KVCacheConfig] = None): + def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig | None = None): super().__init__(vllm_config, kv_cache_config) num_cpu_blocks = self.extra_config.get("num_cpu_blocks") if not num_cpu_blocks: - raise Exception( - "num_cpu_blocks must be specified in kv_connector_extra_config" - ) + raise Exception("num_cpu_blocks must be specified in kv_connector_extra_config") self.num_cpu_blocks: int = num_cpu_blocks # scheduler-side - self._manager: Optional[OffloadingManager] = None + self._manager: OffloadingManager | None = None # worker-side - self._handler: Optional[OffloadingHandler] = None + self._handler: OffloadingHandler | None = None def 
get_manager(self) -> OffloadingManager: if not self._manager: kv_events_config = self.vllm_config.kv_events_config - enable_events = (kv_events_config is not None - and kv_events_config.enable_kv_cache_events) + enable_events = kv_events_config is not None and kv_events_config.enable_kv_cache_events self._manager = LRUOffloadingManager( - CPUBackend(block_size=self.offloaded_block_size, - num_blocks=self.num_cpu_blocks), + CPUBackend(block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks), enable_events=enable_events, ) return self._manager @@ -51,8 +43,7 @@ class NPUOffloadingSpec(OffloadingSpec): self, kv_caches: dict[str, torch.Tensor], attn_backends: dict[str, type[AttentionBackend]], - ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], - OffloadingHandler]]: + ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: if not self._handler: self._handler = CpuNpuOffloadingHandler( attn_backends=attn_backends, diff --git a/vllm_ascend/lora/lora_ops.py b/vllm_ascend/lora/lora_ops.py index 58d0ea60..e63e6783 100644 --- a/vllm_ascend/lora/lora_ops.py +++ b/vllm_ascend/lora/lora_ops.py @@ -16,11 +16,13 @@ import torch -def bgmv_shrink(inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0): +def bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +): return torch.ops._C_ascend.bgmv_shrink( inputs, lora_a_weights, @@ -30,11 +32,13 @@ def bgmv_shrink(inputs: torch.Tensor, ) -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +): return torch.ops._C_ascend.bgmv_expand( inputs, lora_b_weights, @@ -45,16 +49,18 @@ def bgmv_expand(inputs: torch.Tensor, ) -def bgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True): - return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights, - lora_indices_tensor, output_tensor, - slice_offset, slice_size) +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +): + return torch.ops._C_ascend.bgmv_expand( + inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset, slice_size + ) def sgmv_shrink( @@ -69,21 +75,23 @@ def sgmv_shrink( token_nums: int, scaling: float, ): - return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights, - lora_indices_tensor, seq_len_tensor, - output_tensor, scaling) + return torch.ops._C_ascend.sgmv_shrink( + inputs, lora_a_weights, lora_indices_tensor, seq_len_tensor, output_tensor, scaling + ) -def sgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - add_inputs: bool = False): +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + 
b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +): return torch.ops._C_ascend.sgmv_expand( inputs, lora_b_weights, @@ -95,19 +103,20 @@ def sgmv_expand(inputs: torch.Tensor, ) -def sgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False): - return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights, - lora_indices_tensor, seq_len_tensor, - output_tensor, slice_offset, - slice_size) +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +): + return torch.ops._C_ascend.sgmv_expand( + inputs, lora_b_weights, lora_indices_tensor, seq_len_tensor, output_tensor, slice_offset, slice_size + ) diff --git a/vllm_ascend/lora/punica_npu.py b/vllm_ascend/lora/punica_npu.py index 90c2ef5d..885c0765 100644 --- a/vllm_ascend/lora/punica_npu.py +++ b/vllm_ascend/lora/punica_npu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional, Tuple, Union +from collections.abc import Callable import torch from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase @@ -18,26 +18,30 @@ class PunicaWrapperNPU(PunicaWrapperBase): Multi-LoRA, and to provide the interface for the pytorch punica ops. 
""" - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, - device) + def __init__(self, max_num_batched_tokens: int, max_batches: int, device: torch.device | str, **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) refresh_all_lora_classes() self.lora_config = kwargs.get("lora_config") if get_ascend_device_type() == AscendDeviceType._310P or ( - self.lora_config is not None - and self.lora_config.max_lora_rank >= 128): - from vllm.lora.ops.torch_ops import (bgmv_expand, - bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, - sgmv_shrink) + self.lora_config is not None and self.lora_config.max_lora_rank >= 128 + ): + from vllm.lora.ops.torch_ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, + ) else: - from vllm_ascend.lora.lora_ops import (bgmv_expand, - bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, - sgmv_shrink) + from vllm_ascend.lora.lora_ops import ( + bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, + sgmv_expand, + sgmv_expand_slice, + sgmv_shrink, + ) self.bgmv_expand = bgmv_expand self.bgmv_expand_slice = bgmv_expand_slice self.bgmv_shrink = bgmv_shrink @@ -52,7 +56,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): w_t_all: torch.Tensor, scale: float, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return self.sgmv_shrink( @@ -79,7 +83,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): w_t_all: torch.Tensor, add_inputs: bool, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return self.sgmv_expand( @@ -108,7 +112,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_slice_size: int, add_inputs: bool, ): - #No LoRA request, so return directly + # No LoRA request, so return directly if self.no_lora: return self.sgmv_expand_slice( @@ -130,8 +134,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): y_slice_size: int, add_inputs: bool, ): - self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, - y_offset, y_slice_size, add_inputs) + self.bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) def _apply_expand( self, @@ -148,13 +151,10 @@ class PunicaWrapperNPU(PunicaWrapperBase): GEMM of lora'b. """ - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) + expand_slice_fun: Callable = self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, scale: float): + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float): """ Perform the ` y+=x@w_t_all` computation, which is suitable for the GEMM of lora'a. 
@@ -165,14 +165,18 @@ class PunicaWrapperNPU(PunicaWrapperBase): """ y_org = y y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) + shrink_fun: Callable = self._shrink_prefill if self.is_prefill else self._shrink_decode shrink_fun(y, x, w_t_all, scale) y = y.view_as(y_org) - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs): + def add_shrink( + self, + y: tuple[torch.Tensor, ...] | torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + scale: float, + **kwargs, + ): """ Performs GEMM for multiple slices of lora_a. When `is_prefill is` true, it indicates that it is currently the @@ -194,18 +198,19 @@ class PunicaWrapperNPU(PunicaWrapperBase): x = x.view(-1, x.shape[-1]) # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale) - def add_expand(self, - y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs) -> None: + def add_expand( + self, + y: torch.Tensor, + x: tuple[torch.Tensor, ...] | torch.Tensor, + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: tuple[torch.Tensor, ...] | None, + output_slices: tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs, + ) -> None: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -229,8 +234,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): y = y.view(-1, y.shape[-1]) offset_left = offset_start if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) + self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): self._apply_expand( y, @@ -243,12 +247,9 @@ class PunicaWrapperNPU(PunicaWrapperBase): offset_left += output_slices[slice_idx] y = y.view_as(y_org) - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> None: + def add_lora_embedding( + self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs + ) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -263,21 +264,22 @@ class PunicaWrapperNPU(PunicaWrapperBase): """ # Embedding layer only need expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) + expand_fun: Callable = self._expand_prefill if self.is_prefill else self._expand_decode x = x.to(torch.float32) expand_fun(y, x, lora_b_stacked, add_inputs) - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - scale: float, - output_slices: Tuple[int, ...], - *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + scale: float, + output_slices: tuple[int, ...], + *, + buffer: tuple[torch.Tensor, ...] 
| None = None, + **kwargs, + ) -> None: """ Applicable to linear-related lora. @@ -308,27 +310,22 @@ class PunicaWrapperNPU(PunicaWrapperBase): # We set the buffer to be float32 by default, consistent with the # triton op buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) + torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) for _ in range(len(output_slices)) + ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) + self.add_expand(y, buffer, lora_b_stacked, None, output_slices, add_inputs=True, **kwargs) - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + def add_lora_logits( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: torch.Tensor | None = None, + **kwargs, + ) -> None: """ Applies lora specifically for LogitsProcessorWithLoRA. @@ -350,9 +347,7 @@ class PunicaWrapperNPU(PunicaWrapperBase): r = lora_b_stacked.size(-1) if buffer is None: - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) + buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) indices = self.sampler_indices diff --git a/vllm_ascend/lora/utils.py b/vllm_ascend/lora/utils.py index be4fbebe..a0178560 100644 --- a/vllm_ascend/lora/utils.py +++ b/vllm_ascend/lora/utils.py @@ -1,91 +1,75 @@ -from typing import Optional - import vllm from torch import nn from transformers import PretrainedConfig from vllm.config import LoRAConfig -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, - RowParallelLinearWithLoRA, - VocabParallelEmbeddingWithLoRA) +from vllm.lora.layers import ( + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA, +) from vllm.lora.layers.utils import _not_fully_sharded_can_replace -from vllm_ascend.ops.linear import (AscendColumnParallelLinear, - AscendMergedColumnParallelLinear, - AscendQKVParallelLinear, - AscendRowParallelLinear) -from vllm_ascend.ops.vocab_parallel_embedding import \ - AscendVocabParallelEmbedding +from vllm_ascend.ops.linear import ( + AscendColumnParallelLinear, + AscendMergedColumnParallelLinear, + AscendQKVParallelLinear, + AscendRowParallelLinear, +) +from vllm_ascend.ops.vocab_parallel_embedding import AscendVocabParallelEmbedding class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - @classmethod def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is AscendColumnParallelLinear -class AscendMergedColumnParallelLinearWithLoRA( - MergedColumnParallelLinearWithLoRA): - +class AscendMergedColumnParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): @classmethod def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: 
return type(source_layer) is AscendMergedColumnParallelLinear class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA): - @classmethod def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is AscendRowParallelLinear class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA): - @classmethod def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: return type(source_layer) is AscendVocabParallelEmbedding class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA): - - @classmethod - @_not_fully_sharded_can_replace - def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is AscendQKVParallelLinear and len( - packed_modules_list) == 1 - - -class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA): - @classmethod @_not_fully_sharded_can_replace def can_replace_layer( @@ -93,18 +77,28 @@ class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA): source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, - model_config: Optional[PretrainedConfig], + model_config: PretrainedConfig | None, ) -> bool: - return (type(source_layer) is AscendQKVParallelLinear - and len(packed_modules_list) == 3) + return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 1 + + +class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA): + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: list, + model_config: PretrainedConfig | None, + ) -> bool: + return type(source_layer) is AscendQKVParallelLinear and len(packed_modules_list) == 3 def refresh_all_lora_classes(): vllm.lora.utils._all_lora_classes.add(AscendColumnParallelLinearWithLoRA) - vllm.lora.utils._all_lora_classes.add( - AscendMergedColumnParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendMergedColumnParallelLinearWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendRowParallelLinearWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA) vllm.lora.utils._all_lora_classes.add(AscendQKVParallelLinearWithLoRA) - vllm.lora.utils._all_lora_classes.add( - AscendMergedQKVParallelLinearWithLoRA) + vllm.lora.utils._all_lora_classes.add(AscendMergedQKVParallelLinearWithLoRA)
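
GlobalTE.get_transfer_engine (mooncake_transfer_engine.py above) lazily creates a single TransferEngine per process with double-checked locking: an unlocked fast-path check, then a second check under the lock so only one thread runs the expensive initialize() call. A minimal, standalone sketch of that pattern follows; _build_resource is a hypothetical stand-in for the TransferEngine construction and is not part of the patch.

import threading

class LazySingleton:
    def __init__(self):
        self._resource = None
        self._lock = threading.Lock()

    def get(self):
        # Fast path: once the resource exists, no lock is taken.
        if self._resource is None:
            with self._lock:
                # Re-check under the lock so concurrent callers do not
                # both run the expensive initialization.
                if self._resource is None:
                    self._resource = self._build_resource()
        return self._resource

    def _build_resource(self):
        # Stand-in for TransferEngine() construction plus initialize(...).
        return object()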
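
get_transfer_timeout_value (utils.py above) falls back to a value derived from the HCCL RDMA retry settings when ASCEND_TRANSFER_TIMEOUT is unset. Reading 4.096 * 2**HCCL_RDMA_TIMEOUT as a per-attempt timeout in microseconds (the usual RDMA local-ACK-timeout convention; the unit reading is an assumption, the patch does not state it), multiplying by the retry count, converting to milliseconds, and adding a 3000 ms margin gives the default shown below.

# Worked example with the defaults HCCL_RDMA_TIMEOUT=20, HCCL_RDMA_RETRY_CNT=7.
# The arithmetic matches get_transfer_timeout_value(); the unit labels are an
# assumption for illustration.
hccl_rdma_timeout = 20
hccl_rdma_retry_cnt = 7
per_attempt_us = 4.096 * (2 ** hccl_rdma_timeout)              # 4_294_967.296
total = int(per_attempt_us * hccl_rdma_retry_cnt // 1000 + 3000)
print(total)  # 33064, i.e. roughly 33 seconds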
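
align_memory (utils.py above) rounds a tensor's data pointer up to the next multiple of alignment and slices off the corresponding number of leading elements. A worked example of the same pointer arithmetic, with illustrative numbers that are not taken from the patch:

data_ptr = 1000
alignment = 512
element_size = 4  # e.g. a float32 tensor
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment   # 1024
offset = (aligned_addr - data_ptr) // element_size                   # 6 elements
print(aligned_addr, offset)  # 1024 6
# tensor[offset:] then starts at the aligned address.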
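
rearrange_output (utils.py above) splits dim 0 of the all-to-all output into cut_num contiguous chunks, one per source rank, and interleaves them via a view followed by a transpose. A toy illustration of that reordering, under the assumption that the final step (outside the shown hunk) flattens the transposed tensor back along dim 0:

import torch

# cut_num chunks of chunk_size rows each: [0,1,2 | 3,4,5] -> [0,3,1,4,2,5]
cut_num, chunk_size = 2, 3
base = torch.arange(cut_num * chunk_size).view(-1, 1).float()
reshaped = base.view(cut_num, chunk_size, -1)
transposed = reshaped.transpose(0, 1)
print(transposed.reshape(-1).tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]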
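
add_lora_linear (punica_npu.py above) composes the two halves of the standard LoRA computation: add_shrink projects the input through the low-rank A matrices into a float32 buffer, and add_expand projects that buffer through the B matrices and accumulates into the output. Ignoring the stacked-weight layouts, slicing, and prefill/decode kernel dispatch used by the wrapper, the per-slice math is the conventional y += (x @ A^T) * scale @ B^T; a minimal plain-PyTorch reference, with shapes chosen purely for illustration:

import torch

hidden, rank, out_dim, tokens = 16, 4, 16, 3
x = torch.randn(tokens, hidden)
lora_a = torch.randn(rank, hidden)     # shrink weight (low-rank down-projection)
lora_b = torch.randn(out_dim, rank)    # expand weight (up-projection)
y = torch.zeros(tokens, out_dim)
scale = 1.0

buffer = (x @ lora_a.t()) * scale      # "shrink": (tokens, rank), kept in float32
y += buffer @ lora_b.t()               # "expand": accumulate into the output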