[Model] Support DeepSeek-V4

chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase,
)
class MLUCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = ""
):
super().__init__(cpu_group, device, device_group, unique_name)
# init device according to rank
self.device = torch.mlu.current_device()
self.ca_comm: CustomAllreduce | None = None

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
MLUKVConnectors: dict[str, tuple[str, str]] = {
"MLUSharedStorageConnector": (
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector"
),
"MLUNixlConnector": (
"vllm_mlu.distributed.kv_transfer.kv_connector.v1.nixl_connector",
"MLUNixlConnector"
),
}
for name, (module_path, class_name) in MLUKVConnectors.items():
if name not in KVConnectorFactory._registry:
KVConnectorFactory.register_connector(name, module_path, class_name)
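Once registered, these connectors are selected through vLLM's standard KV-transfer configuration. A minimal sketch, assuming the usual `KVTransferConfig` fields; the model path and role shown are placeholders:

from vllm import LLM
from vllm.config import KVTransferConfig

# Any connector name registered above can be used here.
llm = LLM(
    model="/path/to/model",
    kv_transfer_config=KVTransferConfig(
        kv_connector="MLUNixlConnector",
        kv_role="kv_both",
    ),
)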

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class LMCacheConnectorV1_MluHijack(LMCacheConnectorV1):
def response_remote_alloc_once(self) -> None:
self._lmcache_engine.response_remote_alloc_once()
def request_remote_memory_send(self) -> None:
self._lmcache_engine.request_remote_memory_send()
MluHijackObject.apply_hijack(LMCacheConnectorV1,
"response_remote_alloc_once",
LMCacheConnectorV1_MluHijack.response_remote_alloc_once)
MluHijackObject.apply_hijack(LMCacheConnectorV1,
"request_remote_memory_send",
LMCacheConnectorV1_MluHijack.request_remote_memory_send)
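`MluHijackObject.apply_hijack` itself is not part of this diff; conceptually it monkey-patches the upstream class so existing call sites pick up the MLU override. A rough sketch of what such a helper might look like (an assumption, not the actual vllm_mlu implementation):

class MluHijackObjectSketch:
    @staticmethod
    def apply_hijack(target_cls, method, new_method):
        # Accept either a method name or the original function object,
        # then rebind the attribute on the upstream class.
        name = method if isinstance(method, str) else method.__name__
        setattr(target_cls, name, new_method)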

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import math
import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
import zmq
from vllm import envs
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.platforms import _Backend
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
EngineId, NixlConnectorWorker, NixlAgentMetadata, NixlConnectorScheduler, NixlConnector)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
class MLUNixlConnector(NixlConnector):
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
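# Skip NixlConnector.__init__ on purpose: initialize the shared base class
# directly so the MLU-specific scheduler/worker classes below are created
# instead of the upstream ones.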
super(NixlConnector, self).__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: MLUNixlConnectorScheduler | None = (
MLUNixlConnectorScheduler(vllm_config, self.engine_id)
)
self.connector_worker: MLUNixlConnectorWorker | None = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = MLUNixlConnectorWorker(vllm_config, self.engine_id)
class MLUNixlConnectorScheduler(NixlConnectorScheduler):
"""Implementation of Scheduler side methods"""
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
'''
=============================
Modify by vllm_mlu
=============================
@brief: log KV transfer info for remote-prefill requests
'''
if request.kv_transfer_params and request.kv_transfer_params.get("do_remote_prefill", False):
logger.info(f"NIXLConnector update_state_after_alloc: request_id={request.request_id}, "
f"num_prompt_tokens={request.num_prompt_tokens}, "
f"num_external_tokens={num_external_tokens}, "
f"kv_transfer_params={request.kv_transfer_params}")
'''
==================
End of MLU Hijack
==================
'''
params = request.kv_transfer_params
logger.debug(
"NIXLConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens,
params,
)
if not params:
return
if params.get("do_remote_decode"):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
# save all blocks
block_ids = blocks.get_block_ids()[0]
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if block_ids:
self._reqs_need_save[request.request_id] = (request, block_ids)
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(
p in params
for p in ("remote_engine_id", "remote_host", "remote_port")
):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids()
if num_external_tokens > 0
else []
)
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request,
local_block_ids,
)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer",
params,
)
else:
assert num_external_tokens == 0
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
class MLUNixlConnectorWorker(NixlConnectorWorker):
"""Implementation of Worker side methods"""
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
'''
=============================
Add by vllm_mlu
=============================
@brief: kv8 (quantized KV cache) is not supported; unwrap non-tensor cache entries to their first (raw cache) tensor
'''
if not isinstance(first_kv_cache, torch.Tensor):
kv_caches = {key: value[0] for key, value in kv_caches.items()}
_, first_kv_cache = next(iter(kv_caches.items()))
'''
==================
End of MLU Hijack
==================
'''
kv_elem_size = first_kv_cache.element_size()
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that this
# only affects the strides. For MLA, instead, no such layout change is
# required and we resort to the standard layout.
'''
=============================
Add by vllm_mlu
=============================
@brief: detect MLA from the MLU KV cache shape (leading dimension of 1)
'''
use_mla = first_kv_cache.shape[0] == 1
'''
==================
End of MLU Hijack
==================
'''
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
# once it goes live, as a single kv layout is expected for xfers.
if use_mla:
# MLA case.
'''
=============================
Add by vllm_mlu
=============================
@brief: for MLA on MLU, num_blocks is the second dimension of the cache
'''
self.num_blocks = first_kv_cache.shape[1]
'''
==================
End of MLU Hijack
==================
'''
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # MLU layout: [kv_heads, block_size, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
'''
=============================
Add by vllm_mlu
=============================
@brief: MLU kv_cache layout is [2 (k and v), num_blocks, kv_heads, block_size, head_dim]
'''
n_kv_heads, block_size, head_dim = block_shape[-3:]
'''
==================
End of MLU Hijack
==================
'''
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info(
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
"block_shape: %s, per_layer_kv_cache_shape: %s", use_mla,
self.num_blocks, block_shape, first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
# elegantly support MLA and any cases where the K and V tensors
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla or self._use_flashinfer \
else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
caches_data.append(
(base_addr, region_len, cache.device.index, ""))
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
self.num_regions = len(caches_data)
self.num_layers = len(self.kv_caches.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
logger.debug("Done registering descs")
self._registered_descs.append(descs)
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len,
attn_backend_name=self.backend_name)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
daemon=True,
name="nixl_handshake_listener")
self._nixl_handshake_listener_t.start()
ready_event.wait()
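As a sanity check on the size bookkeeping in register_kv_caches, here is the arithmetic with purely illustrative values for the MLU layout [2, num_blocks, kv_heads, block_size, head_dim] (a bf16 cache with 8 KV heads, block_size 16, head_dim 128 and 1024 blocks is assumed):

import math

kv_elem_size = 2                       # bytes per bf16 element (assumed)
block_shape = (8, 16, 128)             # (kv_heads, block_size, head_dim), assumed
num_blocks = 1024                      # assumed

slot_size_bytes = kv_elem_size * 8 * 128             # 2048 bytes per token slot
block_len = kv_elem_size * math.prod(block_shape)    # 32768 bytes per block
region_len = num_blocks * block_len                   # 33554432 bytes per K or V region
print(slot_size_bytes, block_len, region_len)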

View File

@@ -0,0 +1,450 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import hashlib
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
from vllm_mlu.v1.attention.backends.flash_mla import MLAFlashAttentionCommonMetadata
logger = init_logger(__name__)
@dataclass
class ReqMeta:
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Is store or load
is_store: bool
mm_hashes: list[str]
@staticmethod
def make_meta(
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> "ReqMeta":
valid_num_tokens = align_to_block_size(len(token_ids), block_size)
token_ids_tensor = torch.tensor(token_ids)[:valid_num_tokens]
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = (
block_offsets.reshape((1, block_size))
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
)
slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
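# Worked example (illustrative): block_size=4, block_ids=[3, 7] and six
# token_ids give valid_num_tokens=4, so slot_mapping keeps only the first
# block: [3*4+0, 3*4+1, 3*4+2, 3*4+3] == [12, 13, 14, 15].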
return ReqMeta(
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
is_store=is_store,
mm_hashes=mm_hashes,
)
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
self,
token_ids: list[int],
block_ids: list[int],
block_size: int,
is_store: bool,
mm_hashes: list[str],
) -> None:
self.requests.append(
ReqMeta.make_meta(token_ids, block_ids, block_size, is_store, mm_hashes)
)
class SharedStorageConnector(KVConnectorBase_V1):
# NOTE: This is a simple debug implementation of the KV connector.
# It saves / loads the KV cache to / from disk.
# It does extra work which will overwrite the existing prefix cache on the
# device - to remove that overhead, a "mask" would need to be added to ReqMeta.
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
self._storage_path = self._kv_transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp")
logger.info(self._kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1
)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1
)
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = self._get_connector_metadata()
if metadata is None:
logger.warning(
"In connector.start_load_kv, but the connector metadata is None"
)
return
assert isinstance(metadata, SharedStorageConnectorMetadata)
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
logger.warning("In connector.start_load_kv, but the attn_metadata is None")
return
# Load the KV for each request each layer
for request in metadata.requests:
if request.is_store:
continue
logger.info(
"Inject KV cache of %d tokens to the paged memory",
len(request.slot_mapping),
)
for layer_name in forward_context.no_compile_layers:
layer = forward_context.no_compile_layers[layer_name]
# Only process layers that have kv_cache
# attribute (attention layers) Skip non-attention
# layers like FusedMoE/MLP etc.
kv_cache_attr = getattr(layer, "kv_cache", None)
if kv_cache_attr is None:
continue
kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
# Move the loaded cache to the same (MLU) device as the destination layer.
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].to(kv_cache_layer.device)
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLAFlashAttentionCommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, SharedStorageConnectorMetadata)
for request in connector_metadata.requests:
if request.is_store:
filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
tensors = {"kv_cache": kv_cache.detach().cpu()}
safetensors.torch.save_file(tensors, filename)
def wait_for_save(self):
return
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# NOTE: in this debug implementation, we assume that the prompt is
# cached_prompt + newly_generated_single_token
# Therefore, we use prompt_token_ids[:-1] to determine the folder name
# NOTE: in current v1 scheduler, the num_computed_tokens is aligned
# with the block granularity. And it expects the returned blocks and
# num_computed_tokens to also be aligned with the block granularity.
if not self._found_match_for_request(request):
return 0, False
logger.info("External Cache Hit!")
# Now, first num_tokens_to_check tokens are hit, we need to prepare
# the metadata for the worker connector to correctly load the KV
token_ids = request.prompt_token_ids or []
num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
return num_tokens_to_check - num_computed_tokens, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
if num_external_tokens > 0:
self._requests_need_load[request.request_id] = request
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = SharedStorageConnectorMetadata()
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
token_ids = new_req.prompt_token_ids or []
mm_hashes = [f.identifier for f in new_req.mm_features]
if new_req.req_id in self._requests_need_load:
meta.add_request(
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False,
mm_hashes=mm_hashes,
)
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
# but a single request can have both store and load.
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_prompt(token_ids, mm_hashes):
meta.add_request(
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True,
mm_hashes=mm_hashes,
)
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
resumed_from_preemption = req_id in cached_reqs.resumed_req_ids
if not resumed_from_preemption or req_id not in self._requests_need_load:
continue
num_computed_tokens = cached_reqs.num_computed_tokens[i]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
new_block_ids = cached_reqs.new_block_ids[i]
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
meta.add_request(
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in request.mm_features],
)
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()
return meta
# ==============================
# Helper functions
# ==============================
def _found_match_for_request(
self,
request: "Request",
) -> bool:
"""Check if the cache is hit for the request."""
return self._found_match_for_prompt(
list(request.prompt_token_ids or []),
[f.identifier for f in request.mm_features],
)
def _found_match_for_prompt(
self,
prompt_token_ids: list[int],
mm_hashes: list[str],
) -> bool:
num_tokens_to_check = align_to_block_size(
len(prompt_token_ids) - 1, self._block_size
)
foldername = self._generate_foldername_debug(
torch.tensor(prompt_token_ids)[:num_tokens_to_check],
mm_hashes,
create_folder=False,
)
return os.path.exists(foldername)
def _generate_foldername_debug(
self,
token_ids: torch.Tensor,
mm_hashes: list[str],
create_folder=False,
) -> str:
"""Generate a folder name based on the hash of the bytes of the input
ids.
"""
token_bytes = token_ids.numpy().tobytes()
# Add mm_hashes to the bytes being hashed to avoid path traversal and
# to create a canonical key.
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode("utf-8")
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:
os.makedirs(foldername, exist_ok=True)
return foldername
def _generate_filename_debug(
self,
layer_name: str,
token_ids: torch.Tensor,
mm_hashes: list[str],
) -> str:
"""Generate a file name based on the layer name and the hash
of the bytes of the input ids.
"""
foldername = self._generate_foldername_debug(
token_ids, mm_hashes=mm_hashes, create_folder=True
)
return os.path.join(foldername, f"{layer_name}.safetensors")
def align_to_block_size(num_tokens: int, block_size: int) -> int:
"""Round the token count down to a multiple of the block size, always
leaving at least the last token out (e.g. with block_size 16, 17 tokens
-> 16 and a full 16 -> 0)."""
return (num_tokens - 1) // block_size * block_size
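Like the NIXL connector earlier in this commit, the shared-storage connector would normally be enabled through the KV-transfer config; a minimal sketch, assuming the standard `kv_connector_extra_config` plumbing that `get_from_extra_config` reads (role and path are placeholders):

from vllm.config import KVTransferConfig

kv_cfg = KVTransferConfig(
    kv_connector="MLUSharedStorageConnector",
    kv_role="kv_both",
    kv_connector_extra_config={"shared_storage_path": "/tmp/kv_cache_store"},
)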

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from contextlib import contextmanager, nullcontext
from typing import Optional
from dataclasses import dataclass
import torch
from vllm.distributed.parallel_state import (
GroupCoordinator,
GraphCaptureContext,
get_pp_group,
get_tp_group,
)
from vllm.distributed.mlu_parallel_state import (
get_moe_expert_parallel_world_size,
get_moe_expert_parallel_rank,
get_moe_expert_parallel_group,
)
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
@dataclass
class MLUGraphCaptureContext:
stream: torch.mlu.Stream
@contextmanager
def mlu_graph_capture(device: torch.device):
"""
`mlu_graph_capture` is a context manager which should surround the code that
captures the MLU graph. Its main purpose is to ensure that some operations
are run after the graph is captured and before the graph is replayed. It
returns an `MLUGraphCaptureContext` object which contains the data needed
for the capture; currently that is only the stream the capture runs on. This
stream is set to the current MLU stream when the context manager is entered
and reset to the default stream when it is exited, so that the capture runs
on a separate stream and the kernels to capture are explicitly distinguished
from other kernels possibly launched in the background on the default stream.
"""
context = MLUGraphCaptureContext(torch.mlu.Stream(device=device))
with get_tp_group().graph_capture(context), get_pp_group().graph_capture(context):
yield context
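# Usage sketch (assumes the MLU graph-capture flow mirrors the CUDA one;
# run_model_once is a hypothetical helper):
#     with mlu_graph_capture(torch.device("mlu:0")) as capture_ctx:
#         # kernels launched here run on capture_ctx.stream rather than the
#         # default stream, so they can be recorded into a device graph
#         run_model_once()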
@contextmanager
def vllm__distributed__parallel_state__GroupCoordinator__graph_capture(
self,
graph_capture_context: GraphCaptureContext | None = None,
):
if graph_capture_context is None:
stream = torch.mlu.Stream()
graph_capture_context = GraphCaptureContext(stream)
else:
stream = graph_capture_context.stream
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
from vllm_mlu.distributed.device_communicators.mlu_communicator import (
MLUCommunicator,
)
if self.device_communicator is not None:
assert isinstance(self.device_communicator, MLUCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.mlu.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.mlu.stream(stream), maybe_ca_context:
yield graph_capture_context
@dataclass
class CnclEPBuffer:
dispatch_send_token_tensor: torch.Tensor
dispatch_recv_token_tensor: torch.Tensor
combine_send_token_tensor: torch.Tensor
combine_recv_token_tensor: torch.Tensor
class CnclEP:
def __init__(self,
dispatch_token_size: int,
combine_token_size: int,
max_num_tokens_per_rank: int,
num_global_experts: int,
use_quant_dispatch: bool = True) -> None:
nranks = get_moe_expert_parallel_world_size()
rank = get_moe_expert_parallel_rank()
moe_ep_group = get_moe_expert_parallel_group()
self.max_num_tokens_per_rank = max_num_tokens_per_rank
self.use_quant_dispatch = use_quant_dispatch
(
handle,
exchange_info_size,
exchange_info,
dispatch_send_token_tensor,
dispatch_recv_token_tensor,
combine_send_token_tensor,
combine_recv_token_tensor
) = mlu_ops.moe_all2all_create(dispatch_token_size,
combine_token_size,
num_global_experts,
max_num_tokens_per_rank,
rank,
nranks)
self.handle = handle
self.buffer = CnclEPBuffer(
dispatch_send_token_tensor,
dispatch_recv_token_tensor,
combine_send_token_tensor,
combine_recv_token_tensor)
assert exchange_info.ndim == 1, "exchange_info should be 1D"
all_exchange_info = torch.empty((nranks, exchange_info.size(0)),
dtype=exchange_info.dtype,
device=exchange_info.device)
exchange_info = exchange_info.unsqueeze(0)
torch.distributed.all_gather_into_tensor(all_exchange_info,
exchange_info,
group=moe_ep_group.cpu_group,
async_op=False)
mlu_ops.moe_all2all_init(self.handle, all_exchange_info)
torch.distributed.barrier(group=moe_ep_group.cpu_group)
def dispatch(self,
token_byte: int,
token_num: int,
send_layout: torch.Tensor,
send_token_num: torch.Tensor,
recv_layout: torch.Tensor,
recv_token_num: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
) -> None:
'''
The output tensors are modified in place, so they can be used directly
after dispatch finishes.
'''
mlu_ops.moe_all2all_dispatch(self.handle,
token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
def combine(self,
token_byte: int,
token_num: int,
send_src_layout: torch.Tensor,
send_dst_layout: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
) -> None:
mlu_ops.moe_all2all_combine(self.handle,
token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
def destroy(self) -> None:
mlu_ops.moe_all2all_destroy(self.handle)
_CNCLEP: CnclEP | None = None
_CNCLEP_BF16: CnclEP | None = None
def get_cnclep(use_quant_dispatch: bool = True) -> CnclEP:
if use_quant_dispatch:
assert _CNCLEP is not None, "cnclep is not initialized"
return _CNCLEP
else:
assert _CNCLEP_BF16 is not None, "cnclep_bf16 is not initialized"
return _CNCLEP_BF16
def init_cnclep(dispatch_token_size: int,
combine_token_size: int,
max_num_tokens_per_rank: int,
num_global_experts: int,
use_quant_dispatch: bool = True):
if use_quant_dispatch:
global _CNCLEP
assert _CNCLEP is None, "cnclep has been initialized"
_CNCLEP = CnclEP(dispatch_token_size,
combine_token_size,
max_num_tokens_per_rank,
num_global_experts,
use_quant_dispatch)
else:
global _CNCLEP_BF16
assert _CNCLEP_BF16 is None, "cnclep_bf16 has been initialized"
_CNCLEP_BF16 = CnclEP(dispatch_token_size,
combine_token_size,
max_num_tokens_per_rank,
num_global_experts,
use_quant_dispatch)
def cnclep_dispatch(token_byte: int,
token_num: int,
send_layout: torch.Tensor,
send_token_num: torch.Tensor,
recv_layout: torch.Tensor,
recv_token_num: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
use_quant_dispatch: bool = True,
):
if use_quant_dispatch:
_CNCLEP.dispatch(token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
else:
_CNCLEP_BF16.dispatch(token_byte,
token_num,
send_layout,
send_token_num,
recv_layout,
recv_token_num,
send_token,
recv_token)
def cnclep_combine(token_byte: int,
token_num: int,
send_src_layout: torch.Tensor,
send_dst_layout: torch.Tensor,
send_token: Optional[torch.Tensor] = None,
recv_token: Optional[torch.Tensor] = None,
use_quant_dispatch: bool = True,
):
if use_quant_dispatch:
_CNCLEP.combine(token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
else:
_CNCLEP_BF16.combine(token_byte,
token_num,
send_src_layout,
send_dst_layout,
send_token,
recv_token)
def destroy_cnclep():
global _CNCLEP
if _CNCLEP:
_CNCLEP.destroy()
_CNCLEP = None
global _CNCLEP_BF16
if _CNCLEP_BF16:
_CNCLEP_BF16.destroy()
_CNCLEP_BF16 = None
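# Call-order sketch for the expert-parallel all-to-all helpers above (all
# sizes are placeholders; the layout/token tensors are assumed to be prepared
# by the MoE layer and to follow the signatures defined in this file):
#     init_cnclep(dispatch_token_size, combine_token_size,
#                 max_num_tokens_per_rank, num_global_experts,
#                 use_quant_dispatch=True)
#     cnclep_dispatch(token_byte, token_num, send_layout, send_token_num,
#                     recv_layout, recv_token_num, send_token, recv_token)
#     ... run the local experts on the received tokens ...
#     cnclep_combine(token_byte, token_num, send_src_layout, send_dst_layout,
#                    send_token, recv_token)
#     destroy_cnclep()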
MluHijackObject.apply_hijack(GroupCoordinator,
GroupCoordinator.graph_capture,
vllm__distributed__parallel_state__GroupCoordinator__graph_capture)