[Model] Support DeepSeek-V4

This commit is contained in:
chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions

vllm_mlu/v1/__init__.py Normal file

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project


@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project


@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

File diff suppressed because it is too large


@@ -0,0 +1,404 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from collections import OrderedDict, deque
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.gdn_attn import (GDNAttentionMetadataBuilder,
GDNAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,)
class DeviceAwareLocalIdMapper:
def __init__(self, batch_size: int):
if batch_size <= 0:
raise ValueError("batch_size must be positive")
self.batch_size = batch_size
self.global_to_local: OrderedDict[int, int] = OrderedDict()
self.local_to_global = {}
self.available_local_ids = deque(range(batch_size))
def batch_get_local_ids(self, global_id_tensor: torch.Tensor) -> torch.Tensor:
original_device = global_id_tensor.device
original_shape = global_id_tensor.shape
flat_global_cpu = global_id_tensor.cpu().numpy().ravel()
num_elements = flat_global_cpu.size
local_ids_cpu = torch.empty(num_elements, dtype=global_id_tensor.dtype)
g2l = self.global_to_local
unique_miss_set = set()
# Pass 1: handle hits and collect unique misses
for i, gid in enumerate(flat_global_cpu):
if gid in g2l:
local_id = g2l[gid]
local_ids_cpu[i] = local_id
g2l.move_to_end(gid)
else:
local_ids_cpu[i] = -1
unique_miss_set.add(gid)
# Pass 2: assign local IDs to unique new global IDs
new_mappings = {}
available = self.available_local_ids
local_to_global = self.local_to_global
for gid in unique_miss_set:
if len(g2l) >= self.batch_size:
old_gid, old_local = g2l.popitem(last=False)
available.append(old_local)
local_to_global.pop(old_local, None)
new_local = available.popleft()
g2l[gid] = new_local
local_to_global[new_local] = gid
new_mappings[gid] = new_local
# Pass 3: fill in all miss positions
for i, gid in enumerate(flat_global_cpu):
if local_ids_cpu[i].item() == -1:
local_ids_cpu[i] = new_mappings[gid]
return local_ids_cpu.to(original_device).view(original_shape)
def reset(self):
self.global_to_local.clear()
self.local_to_global.clear()
self.available_local_ids = deque(range(self.batch_size))
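# A minimal usage sketch (illustrative only, never called by the hijack): the
# mapper acts as an LRU cache from global block ids onto a bounded pool of
# local slot ids, recycling the least-recently-used slot once all batch_size
# slots are taken. Single-element lookups keep the assignment deterministic.
def _local_id_mapper_demo():
    mapper = DeviceAwareLocalIdMapper(batch_size=2)
    assert mapper.batch_get_local_ids(torch.tensor([10])).item() == 0  # miss -> slot 0
    assert mapper.batch_get_local_ids(torch.tensor([20])).item() == 1  # miss -> slot 1
    assert mapper.batch_get_local_ids(torch.tensor([10])).item() == 0  # hit, refreshes 10
    # Pool exhausted: id 30 evicts the least recently used id (20), reusing slot 1.
    assert mapper.batch_get_local_ids(torch.tensor([30])).item() == 1
    mapper.reset()  # all mappings cleared, slots 0..1 free again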
def vllm__v1__attention__backends__GDNAttentionMetadataBuilder____init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
assert isinstance(kv_cache_spec, MambaSpec)
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
self.compilation_config.max_cudagraph_capture_size,
)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
dtype=torch.int32,
device=device,
)
self.non_spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.spec_sequence_masks = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.bool,
device=device,
)
self.spec_token_indx = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.non_spec_token_indx = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.non_spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
device=device,
)
self.num_accepted_tokens = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: support qwen3-next
'''
self.mapper = DeviceAwareLocalIdMapper(self.vllm_config.mlu_config.mamba_support_max_batch_size)
'''
==================
End of MLU Hijack
==================
'''
def vllm__v1__attention__backends__GDNAttentionMetadataBuilder__build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: torch.Tensor | None = None,
num_decode_draft_tokens_cpu: torch.Tensor | None = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (
not self.use_spec_decode
or num_decode_draft_tokens_cpu is None
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
.sum()
.item()
== 0
):
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
else:
spec_sequence_masks = spec_sequence_masks.to(
query_start_loc.device, non_blocking=True
)
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1)
)
num_spec_decode_tokens = 0
spec_token_indx = None
non_spec_token_indx = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
non_spec_query_start_loc = query_start_loc
num_accepted_tokens = None
else:
query_lens = query_start_loc[1:] - query_start_loc[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
num_spec_decode_tokens = (
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
)
if num_prefills == 0 and num_decodes == 0:
spec_token_size = min(
num_spec_decodes * (self.num_spec + 1),
query_start_loc[-1].item(),
)
spec_token_indx = torch.arange(
spec_token_size,
dtype=torch.int32,
device=query_start_loc.device,
)
non_spec_token_indx = torch.empty(
0, dtype=torch.int32, device=query_start_loc.device
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
non_spec_query_start_loc = None
else:
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
index = torch.argsort(spec_token_masks)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]
spec_state_indices_tensor = m.block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = m.block_table_tensor[
~spec_sequence_masks, 0
]
spec_query_start_loc = torch.zeros(
num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
)
non_spec_query_start_loc = torch.zeros(
query_lens.size(0) - num_spec_decodes + 1,
dtype=torch.int32,
device=query_start_loc.device,
)
torch.cumsum(
query_lens[~spec_sequence_masks],
dim=0,
out=non_spec_query_start_loc[1:],
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
if num_prefills > 0:
has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None:
has_initial_state = has_initial_state[~spec_sequence_masks]
nums_dict, batch_ptr, token_chunk_offset_ptr = (
compute_causal_conv1d_metadata(non_spec_query_start_loc)
)
else:
has_initial_state = None
num_actual_tokens = (
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
)
# prepare tensors for cudagraph
#
# With speculative decoding, the xgrammar backend may roll back tokens,
# causing some sequences to have fewer draft tokens than self.num_spec.
#
# In the above cases, the max possible batch size for n tokens is
# min(n, cudagraph_max_bs).
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True
)
spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks, non_blocking=True
)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)
assert non_spec_token_indx is not None and spec_token_indx is not None
self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
non_spec_token_indx, non_blocking=True
)
non_spec_token_indx = self.non_spec_token_indx[
: non_spec_token_indx.size(0)
]
self.spec_token_indx[: spec_token_indx.size(0)].copy_(
spec_token_indx, non_blocking=True
)
spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True
)
spec_num_query_tokens = spec_query_start_loc[-1] # type: ignore[index]
spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)
self.num_accepted_tokens[:num_spec_decodes].copy_(
num_accepted_tokens, non_blocking=True
)
num_accepted_tokens = self.num_accepted_tokens[:batch_size]
num_accepted_tokens[num_spec_decodes:].fill_(1)
if (
self.use_full_cuda_graph
and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs
):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = num_actual_tokens
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True
)
non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
:batch_size
]
non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)
self.non_spec_query_start_loc[: num_decodes + 1].copy_(
non_spec_query_start_loc, non_blocking=True
)
non_spec_num_query_tokens = non_spec_query_start_loc[-1] # type: ignore[index]
non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)
'''
=============================
Modify by vllm_mlu
=============================
@brief: support qwen3-next
'''
non_spec_state_indices_tensor = self.mapper.batch_get_local_ids(non_spec_state_indices_tensor)
'''
==================
End of MLU Hijack
==================
'''
attn_metadata = GDNAttentionMetadata(
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=num_actual_tokens,
has_initial_state=has_initial_state,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
spec_sequence_masks=spec_sequence_masks,
spec_token_indx=spec_token_indx,
non_spec_token_indx=non_spec_token_indx,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
token_chunk_offset_ptr=token_chunk_offset_ptr,
)
return attn_metadata
MluHijackObject.apply_hijack(GDNAttentionMetadataBuilder,
GDNAttentionMetadataBuilder.__init__,
vllm__v1__attention__backends__GDNAttentionMetadataBuilder____init__)
MluHijackObject.apply_hijack(GDNAttentionMetadataBuilder,
GDNAttentionMetadataBuilder.build,
vllm__v1__attention__backends__GDNAttentionMetadataBuilder__build)
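# Conceptually (a sketch, assuming MluHijackObject is a monkey-patch registry
# that records the original attribute so the hijack can later be undone), the
# two calls above amount to:
#   GDNAttentionMetadataBuilder.__init__ = \
#       vllm__v1__attention__backends__GDNAttentionMetadataBuilder____init__
#   GDNAttentionMetadataBuilder.build = \
#       vllm__v1__attention__backends__GDNAttentionMetadataBuilder__build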


@@ -0,0 +1,934 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv, round_down
from vllm.attention.backends.utils import MLADims
from vllm.config import ModelConfig
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend, MLACommonPrefillMetadata,
MLACommonDecodeMetadata, MLACommonMetadata,
MLACommonMetadataBuilder, M, QueryLenSupport,
use_cudnn_prefill, use_flashinfer_prefill,
use_trtllm_ragged_deepseek_prefill,
FlashInferPrefillMetadata,
CudnnPrefillMetadata,
MLACommonImpl,
CUDNN_WORKSPACE_SIZE
)
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, split_decodes_and_prefills,
infer_global_hyperparameters, get_per_layer_parameters,
)
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
MLAAttentionImpl,
)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.v1.attention.backends.flash_attn import MLUFlashAttentionImpl
from vllm_mlu.v1.attention.backends.utils import (
MLUCommonAttentionMetadata, get_common_metadata,
MLUInferMode)
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.platforms import current_platform
from vllm import envs
try:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401
flashinfer_available = True
except ImportError:
BatchPrefillWithRaggedKVCacheWrapper = object
flashinfer_available = False
logger = init_logger(__name__)
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class MLACommonBackend_MluHijack(MLACommonBackend):
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576, 512]
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
if hf_text_config.model_type == "deepseek_v4":
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.head_dim,
qk_nope_head_dim=hf_text_config.head_dim - hf_text_config.rope_head_dim,
qk_rope_head_dim=hf_text_config.rope_head_dim,
v_head_dim=hf_text_config.head_dim,
)
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)
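# Worked sketch (hypothetical numbers, for illustration only): for a
# "deepseek_v4" text config with head_dim=512 and rope_head_dim=64, the
# branch above yields
#   kv_lora_rank     = 512  (head_dim)
#   qk_nope_head_dim = 448  (head_dim - rope_head_dim)
#   qk_rope_head_dim = 64   (rope_head_dim)
#   v_head_dim       = 512  (head_dim)
# so the MLA KV-cache head size, kv_lora_rank + qk_rope_head_dim, comes to
# 576, consistent with get_supported_head_sizes() returning [576, 512].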
class MLACommonMetadataBuilder_MluHijack(MLACommonMetadataBuilder):
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: type[M] | None = None,
supports_dcp_with_varlen: bool = False,
):
self.metadata_cls = (
metadata_cls if metadata_cls is not None else MLACommonMetadata
)
self.kv_cache_spec = kv_cache_spec
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
self.device = device
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
# Don't try to access the runner on AMD
if self.aot_schedule:
self.page_size = self.kv_cache_spec.block_size
self.chunked_prefill_workspace_size = (
self.determine_chunked_prefill_workspace_size(vllm_config)
)
if self.dcp_world_size > 1:
# Note(hc): The local kvcache is incomplete when DCP is triggered,
# an additional kvcache allgather across the DCP group is therefore
# required, so the workspace has to be enlarged by 1/DCP relative
# to the original TP allocation.
assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0
self.chunked_prefill_workspace = torch.empty(
(
self.chunked_prefill_workspace_size
+ self.chunked_prefill_workspace_size // self.dcp_world_size,
self.model_config.get_head_size(),
),
dtype=self.model_config.dtype,
device=device,
)
else:
self.chunked_prefill_workspace = torch.empty(
(
self.chunked_prefill_workspace_size,
self.model_config.get_head_size(),
),
dtype=self.model_config.dtype,
device=device,
)
self._use_cudnn_prefill = use_cudnn_prefill()
self._use_fi_prefill = use_flashinfer_prefill()
self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
self.prefill_metadata_cls = (
FlashInferPrefillMetadata
if self._use_fi_prefill
else CudnnPrefillMetadata
if self._use_cudnn_prefill
else MLACommonPrefillMetadata
)
if self._use_fi_prefill:
self._workspace_buffer = torch.empty(
envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=device,
)
self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []
self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
)
if self._use_trtllm_ragged_prefill:
self._workspace_buffer = torch.empty(
envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=device,
)
if self._use_cudnn_prefill:
self.cudnn_workspace = torch.empty(
CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
dtype=torch.int8,
device=device,
)
supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
self._init_reorder_batch_threshold(
self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
)
# Validate consistency between query_len_support and reorder_batch_threshold
if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
assert self.reorder_batch_threshold == 1, (
f"reorder_batch_threshold must be 1 when query_len_support is "
f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
)
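# Worked sketch of the DCP workspace sizing above (hypothetical numbers): with
# chunked_prefill_workspace_size = 8192 tokens and dcp_world_size = 4, the
# kvcache allgather across the DCP group needs an extra 1/4 of the original
# TP allocation, so the workspace is allocated as
# 8192 + 8192 // 4 = 10240 rows of head_size.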
MluHijackObject.apply_hijack(MLACommonBackend,
MLACommonBackend.get_supported_head_sizes,
MLACommonBackend_MluHijack.get_supported_head_sizes)
MluHijackObject.apply_hijack(MLACommonMetadataBuilder,
MLACommonMetadataBuilder.__init__,
MLACommonMetadataBuilder_MluHijack.__init__)
class FlashMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "FLASHMLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> type["FlashMLAMetadata"]:
return FlashMLAMetadata
@staticmethod
def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashMLAImpl"]:
return FlashMLAImpl
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
return (1, num_blocks, num_kv_heads, block_size, head_size)
@staticmethod
def get_kv_cache_scale_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
) -> tuple[int, ...]:
return (1, num_blocks, num_kv_heads, block_size)
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576, 512]
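# Shape sketch (illustrative): for MLA the KV cache stores a single combined
# latent + rope head, so with num_blocks=1000, block_size=16, num_kv_heads=1
# and head_size=576, get_kv_cache_shape() returns (1, 1000, 1, 16, 576) and
# get_kv_cache_scale_shape() returns (1, 1000, 1, 16).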
@dataclass
class FlashMLAPrefillMetadata(MLACommonPrefillMetadata):
num_prefills: int = -1 # for gather_cache
max_seq_len: int = -1 # for attn forward
@property
def block_tables(self):
return self.block_table
@property
def context_chunk_cu_seq_lens(self):
if self.chunked_context is None:
return None
return self.chunked_context.cu_seq_lens
@property
def context_chunk_starts(self):
if self.chunked_context is None:
return None
return self.chunked_context.starts
@property
def context_chunk_seq_tot(self):
if self.chunked_context is None:
return None
return self.chunked_context.seq_tot
@property
def context_chunk_max_seq_lens(self):
if self.chunked_context is None:
return None
return self.chunked_context.max_seq_lens
@property
def context_chunk_workspace(self):
if self.chunked_context is None:
return None
return self.chunked_context.workspace
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
tile_scheduler_metadata: torch.Tensor
num_splits: torch.Tensor
# add for mlu rope and attn forward
query_start_loc: torch.Tensor # for rope
max_query_len: int # for rope
max_seq_len: int = -1 # for attn forward
@dataclass
class FlashMLAMetadata(MLACommonMetadata):
num_prefill_tokens: Optional[int] = None
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
reorder_batch_threshold: int = 128 # process small prefills with decode pathway
# ^ TODO(matt): tune this
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(
kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
)
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
'''
=============================
Modify by vllm_mlu
=============================
@brief: 1. set decoder_query_len for mtp
@brief: 2. init chunk workspace for prefix_caching only
@brief: 3. set prefill_metadata_cls
@brief: 4. add deepseek v3.2 infos
'''
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
speculative_config = vllm_config.speculative_config
self.num_speculative_tokens = (speculative_config.num_speculative_tokens
if speculative_config is not None else 0)
self.decoder_query_len = 1 + self.num_speculative_tokens
self.max_model_len = self.model_config.max_model_len
self.is_deepseek_v32 = self.model_config.hf_text_config.model_type == "deepseek_v32"
self.enable_caching = cache_config.enable_prefix_caching
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
if (not self.is_deepseek_v32 and not self.chunked_prefill_enabled and
(mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and self.enable_caching)):
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(
8 * self.model_config.max_model_len, 4 *
scheduler_config.max_num_seqs * cache_config.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
self.prefill_metadata_cls = FlashMLAPrefillMetadata
'''
==================
End of MLU Hijack
==================
'''
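# Worked sketch of the workspace sizing above (hypothetical numbers): with
# max_model_len = 4096, max_num_seqs = 256 and block_size = 16,
#   max(8 * 4096, 4 * 256 * 16) = max(32768, 16384) = 32768
#   min(32768, 128 * 1024)      = 32768 tokens,
# which satisfies the assert since 32768 >= 256 * 16 = 4096.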
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back, using the least
# amount of swaps possible. (NOTE: for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound; TODO(lucas): figure out
# a better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
# mlu v1 mtp forces decoder_query_len = 1 for k > 1, so we must set it again
self.decoder_query_len = 1 + self.num_speculative_tokens
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not;
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
'''
=============================
Modify by vllm_mlu
=============================
@brief: record prefill and decode requests and their token counts so that
forward can call chunked flash-attention and single-query attention
respectively.
@Notes: a request counts as decode only if all prompt tokens are computed.
'''
req_index = input_batch.req_id_to_index.get(req_id)
all_prompt_tokens_has_computed = (
input_batch.num_computed_tokens_cpu[req_index] >=
input_batch.num_prompt_tokens[req_index])
if num_tokens <= self.decoder_query_len and all_prompt_tokens_has_computed:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
'''
==================
End of MLU Hijack
==================
'''
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new requests are generally appended to the
# persistent batch so they should already be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back",
# i.e. past where the last decode should be in the reordered batch, with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
return modified_batch
def _build_decode(
self,
block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor,
query_start_loc: torch.Tensor,
max_query_len: int,
max_seq_len: int,
) -> FlashMLADecodeMetadata:
'''
=============================
Modify by vllm_mlu
=============================
@brief: set tile_scheduler_metadata and num_splits to None.
@brief: set dcp_tot_seq_lens_device.
'''
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
tile_scheduler_metadata=None,
num_splits=None,
dcp_tot_seq_lens=None,
# for mlu
max_seq_len=max_seq_len,
query_start_loc=query_start_loc,
max_query_len=max_query_len
)
'''
==================
End of MLU Hijack
==================
'''
def build_for_cudagraph_capture(
self, common_attn_metadata: MLUCommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
if m.infer_mode == MLUInferMode.DECODE_ONLY:
assert m.num_reqs * m.max_query_len == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
return self.build(0, m)
def build(self,
common_prefix_len: int,
common_attn_metadata: MLUCommonAttentionMetadata,
fast_build: bool = False,
input_batch: Optional["InputBatch"] = None) -> M:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.device
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
query_seq_lens_cpu)
'''
=============================
Modify by vllm_mlu
=============================
@brief: support normal and mtp input split
'''
if input_batch is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata,
self.decoder_query_len)
else:
num_decodes, num_prefills = input_batch.split_decodes_and_prefills()
num_decode_tokens = common_attn_metadata.query_start_loc_cpu[num_decodes].item()
num_prefill_tokens = num_tokens - num_decode_tokens
'''
==================
End of MLU Hijack
==================
'''
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
prefill_metadata = None
if num_prefills > 0:
reqs_start = num_decodes # prefill_start
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
'''
=============================
Modify by vllm_mlu
=============================
@brief: avoid missing buffers when prefill_only + mlugraph
'''
if num_decodes > 0:
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
else:
prefill_query_start_loc = query_start_loc
'''
==================
End of MLU Hijack
==================
'''
chunked_context_metadata = None
if ((self.chunked_prefill_enabled or
(mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and
self.enable_caching and
common_attn_metadata.is_chunked)
) and num_prefills > 0 and max_context_len_cpu > 0):
# NOTE: it is recommended you read the `Chunked Prefill` section
# in the comment at the top of the file before trying to
# understand the following code
# currently we allocate an equal amount of workspace for each
# prefill in the batch, we could probably use a more advanced
# algorithm here and allocate more workspace to prefills with
# longer context lengths
if self.is_deepseek_v32:
max_context_chunk = self.max_model_len
else:
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
if self.aot_schedule:
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk,
self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# `num_prefills_with_context = 4`, create a tensor that looks
# like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
# Note(simon): this is done in CPU because of downstream's
# of `to_list`.
chunk_starts = \
torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) \
* max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata_cls = \
FlashMLAPrefillMetadata.ChunkedContextMetadata
chunked_context_metadata = \
chunked_context_metadata_cls(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
workspace=getattr(self, "chunked_prefill_workspace", None),
)
if not self.is_deepseek_v32:
assert max(chunked_context_metadata.max_seq_lens) <= \
self.chunked_prefill_workspace_size
prefill_metadata = self.prefill_metadata_cls(
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
# for mlu
num_prefills=num_prefills,
max_seq_len=common_attn_metadata.seq_lens_cpu[reqs_start:].max().item(),
)
decode_metadata = None
if num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
query_start_loc=query_start_loc[:num_decodes + 1],
max_query_len=query_seq_lens_cpu[:num_decodes].max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu[:num_decodes].max().item(),
)
attn_metadata = self.metadata_cls(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=num_tokens,
query_start_loc=query_start_loc,
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
# MLACommonMetadata Chunk prefill specific
num_decodes=num_decodes,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
prefill=prefill_metadata,
decode=decode_metadata,
)
return attn_metadata
def can_run_in_cudagraph(
self, common_attn_metadata: MLUCommonAttentionMetadata) -> bool:
return common_attn_metadata.max_query_len == self.decoder_query_len
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False
class FlashMLAImpl(MLUFlashAttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl")
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
kwargs: Optional[dict[str, Any]] = {},
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")
if attn_metadata is None:
# Profiling run.
return output
out_lse = None
# use default common metadata if kwargs does not have common_metadata
common_metadata: MLUCommonAttentionMetadata = kwargs.get("common_metadata", None)
if common_metadata is None:
common_metadata = get_common_metadata()
only_prefill = kwargs.get("only_prefill", False)
only_decode = kwargs.get("only_decode", False)
attn_bias = kwargs.get("attn_bias", None)
assert only_prefill != only_decode, "exactly one of only_prefill and only_decode must be True."
if only_prefill:
cu_seqlens_q = attn_metadata.prefill.query_start_loc
cu_seqlens_kv = common_metadata.query_start_loc
seqused_k = common_metadata.seq_lens[attn_metadata.num_decodes:]
max_seqlen_q = attn_metadata.prefill.max_query_len
max_seqlen_k = attn_metadata.prefill.max_seq_len
block_table = attn_metadata.prefill.block_table
num_actual_tokens = attn_metadata.num_prefill_tokens
else:
cu_seqlens_q = None # unused
cu_seqlens_kv = None # unused
seqused_k = common_metadata.seq_lens[:attn_metadata.num_decodes]
max_seqlen_q = None # unused
max_seqlen_k = common_metadata.max_seq_len
block_table = attn_metadata.decode.block_table
num_actual_tokens = attn_metadata.num_decode_tokens
skip_process_cache = ((self.use_mla
and (common_metadata.is_prefill_only
or self.use_fused_mla_qkv
or only_prefill))
or self.kv_sharing_target_layer_name is not None)
kv_cache_, kv_cache_scale_, kv_cache_index_ = kv_cache
key_cache = kv_cache_[0]
value_cache = None if self.use_mla else kv_cache_[1]
key_cache_scale, value_cache_scale = None, None
if kv_cache_scale_.numel() > 0:
key_cache_scale = kv_cache_scale_[0]
value_cache_scale = None if self.use_mla else kv_cache_scale_[1]
if not skip_process_cache:
if is_quantized_kv_cache(self.kv_cache_dtype):
mlu_ops.quant_to_paged_cache(
k=key[:num_actual_tokens],
v=(None if self.use_mla else value[:num_actual_tokens]),
k_cache=key_cache,
v_cache=value_cache,
k_cache_quant_scale=key_cache_scale,
v_cache_quant_scale=value_cache_scale,
slot_mapping=attn_metadata.slot_mapping.flatten(),
)
else:
mlu_ops.reshape_paged_cache(
k=key[:num_actual_tokens],
v=(None if self.use_mla else value[:num_actual_tokens]),
k_cache=key_cache,
v_cache=value_cache,
slot_mapping=attn_metadata.slot_mapping.flatten()
)
alibi_slopes = None if self.alibi_slopes is None else \
self.alibi_slopes.repeat(seqused_k.shape[0], 1)
if kwargs.get("model_type", "") == "deepseek_v32":
from vllm_mlu.model_executor.models.sp_utils import get_sp_forward_context
sp_context = get_sp_forward_context()
if sp_context is not None and sp_context.is_v32:
num_actual_tokens = sp_context.sp_attn_metadata.num_prefill_tokens
decode_query = query[:num_actual_tokens].view(-1, self.num_heads, self.head_size)
head_size_v = value.shape[-1] if self.use_mla else self.head_size
decode_output = output[:num_actual_tokens].view(-1, self.num_heads, head_size_v)
decode_query = decode_query.unsqueeze(1) # treat tokens as the batch dim
decode_output = decode_output.unsqueeze(1)
q_quant_scale = kwargs.get("q_quant_scale", None)
if q_quant_scale is not None:
q_quant_scale = q_quant_scale[:num_actual_tokens].view(-1, self.num_heads)
q_quant_scale = q_quant_scale.unsqueeze(1)
mlu_ops.single_query_cached_kv_attn(
q=decode_query,
k_cache=key_cache,
v_cache=value_cache,
out=decode_output,
block_tables=kwargs.get("new_block_tables", None),
context_lens=kwargs.get("new_context_lens", None),
k_cache_quant_scale=key_cache_scale,
v_cache_quant_scale=value_cache_scale,
alibi_slopes=alibi_slopes,
max_contxt_len=kwargs.get("index_topk", None),
windows_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
windows_size_right=(-1 if self.sliding_window is None else self.sliding_window[0]),
softmax_scale=self.scale,
head_size_v=(-1 if not self.use_mla else head_size_v),
compute_dtype=attn_metadata.decode.compute_dtype,
q_quant_scale=q_quant_scale,
decoder_attn_dtype=self.decoder_attn_dtype,
mask=attn_bias,
)
return output
if common_metadata.is_prefill_only or only_prefill:
# prefill only
prefill_causal = kwargs.get("prefill_causal", True)
cu_seqlens_q = kwargs.get("cu_seq_lens_q", cu_seqlens_q)
cu_seqlens_kv = kwargs.get("cu_seq_lens_kv", cu_seqlens_kv)
max_seqlen_q = kwargs.get("max_seq_len_q", max_seqlen_q)
max_seqlen_k = kwargs.get("max_seq_len_kv", max_seqlen_k)
return_lse = kwargs.get("return_lse", False)
num_prefill_query_tokens = common_metadata.num_prefill_query_tokens
num_prefill_kv_tokens = common_metadata.num_prefill_kv_tokens
use_f32 = attn_bias is not None and attn_bias.dtype == torch.float32
if use_f32:
f32_output = torch.empty_like(output, dtype=torch.float32)
attn_output_list = mlu_ops.flash_attention(
q=query[:num_prefill_query_tokens].to(torch.float32) if use_f32 else query[:num_prefill_query_tokens],
k=key[:num_prefill_kv_tokens].to(torch.float32) if use_f32 else key[:num_prefill_kv_tokens],
v=value[:num_prefill_kv_tokens].to(torch.float32) if use_f32 else value[:num_prefill_kv_tokens],
out=f32_output[:num_prefill_query_tokens] if use_f32 else output[:num_prefill_query_tokens],
cu_seq_lens_q=cu_seqlens_q,
cu_seq_lens_kv=cu_seqlens_kv,
alibi_slope=alibi_slopes,
attn_bias=attn_bias,
max_seq_len_q=max_seqlen_q,
max_seq_len_kv=max_seqlen_k,
softmax_scale=self.scale,
is_causal=prefill_causal,
window_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
window_size_right=(-1 if self.sliding_window is None else self.sliding_window[1]),
compute_dtype=self.prefill_compute_dtype,
return_lse=return_lse,
q_quant_dtype=self.prefill_q_dtype,
k_quant_dtype=self.prefill_k_dtype,
v_quant_dtype=self.prefill_v_dtype
)
if use_f32:
output[:num_prefill_query_tokens].copy_(f32_output[:num_prefill_query_tokens])
if return_lse:
out_lse = attn_output_list[1]
else:
batch_size = block_table.shape[0]
# decode only
decode_query = query[:num_actual_tokens].view(batch_size, -1, self.num_heads, self.head_size)
head_size_v = value.shape[-1] if self.use_mla else self.head_size
decode_output = output[:num_actual_tokens].view(batch_size, -1, self.num_heads, head_size_v)
q_quant_scale = kwargs.get("q_quant_scale", None)
if q_quant_scale is not None:
q_quant_scale = q_quant_scale[:num_actual_tokens].view(batch_size, -1, self.num_heads)
mlu_ops.single_query_cached_kv_attn(
q=decode_query,
k_cache=key_cache,
v_cache=value_cache,
out=decode_output,
block_tables=block_table,
context_lens=seqused_k,
k_cache_quant_scale=key_cache_scale,
v_cache_quant_scale=value_cache_scale,
alibi_slopes=alibi_slopes,
max_contxt_len=max_seqlen_k,
windows_size_left=(-1 if self.sliding_window is None else self.sliding_window[0]),
windows_size_right=(-1 if self.sliding_window is None else self.sliding_window[0]),
softmax_scale=self.scale,
head_size_v=(-1 if not self.use_mla else head_size_v),
compute_dtype=attn_metadata.decode.compute_dtype,
q_quant_scale=q_quant_scale,
decoder_attn_dtype=self.decoder_attn_dtype,
mask=attn_bias,
)
return output if out_lse is None else (output, out_lse)


@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
import numpy as np
import pandas as pd
import torch
from typing import TYPE_CHECKING, Union
from dataclasses import dataclass
from enum import Enum
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.forward_context import get_forward_context
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
COMMON_METADATA_STR: str = "common_metadata"
class MLUInferMode(Enum):
CHUNKED = 1
PREFILL_ONLY = 2
DECODE_ONLY = 3
@classmethod
def build(
cls,
max_query_len,
max_computed_tokens,
uniform_decode_query_len: int = 1,
) -> Enum:
if max_query_len <= uniform_decode_query_len:
return MLUInferMode.DECODE_ONLY
elif max_computed_tokens == 0:
return MLUInferMode.PREFILL_ONLY
else:
return MLUInferMode.CHUNKED
@property
def is_prefill_only(self):
return self == MLUInferMode.PREFILL_ONLY
@property
def is_decode_only(self):
return self == MLUInferMode.DECODE_ONLY
@property
def is_chunked(self):
return self == MLUInferMode.CHUNKED
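# A minimal sketch of the mode decision (illustrative only):
#   MLUInferMode.build(max_query_len=1,   max_computed_tokens=50)  -> DECODE_ONLY
#   MLUInferMode.build(max_query_len=128, max_computed_tokens=0)   -> PREFILL_ONLY
#   MLUInferMode.build(max_query_len=128, max_computed_tokens=50)  -> CHUNKED
# With speculative decoding, uniform_decode_query_len is raised to
# 1 + num_speculative_tokens so short MTP batches still count as decode.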
@dataclass
class MLUCommonAttentionMetadata(CommonAttentionMetadata):
"""
Attention metadata attributes that can be shared by layers in different KV
cache groups, which thus have different block tables.
"""
seq_start_loc: torch.Tensor | None = None
seq_start_loc_cpu: torch.Tensor | None = None
"""(batch_size + 1,), the start location of each request in the input key/value sequence."""
num_input_tokens: int = 0
"""Number of query tokens with padding."""
num_prefill_query_tokens: int = 0
"""Number of query tokens in prefill phase."""
num_prefill_kv_tokens: int = 0
"""Number of key/value tokens in prefill phase."""
infer_mode: MLUInferMode | None = None
"""Inference mode for flash attention."""
@property
def is_prefill_only(self):
return self.infer_mode == MLUInferMode.PREFILL_ONLY
@property
def is_decode_only(self):
return self.infer_mode == MLUInferMode.DECODE_ONLY
@property
def is_chunked(self):
return self.infer_mode == MLUInferMode.CHUNKED
@classmethod
def build(
cls,
query_start_loc, query_start_loc_cpu,
seq_lens, seq_lens_cpu,
num_computed_tokens_cpu,
num_reqs, num_actual_tokens, max_query_len,
block_table_tensor, slot_mapping,
seq_start_loc, is_start_loc_match,
num_input_tokens: int = 0,
num_speculative_tokens: int = 0,
has_prefill_reqs: bool = False
):
"""Build attention metadata for MLU inference.
Args:
has_prefill_reqs: Whether there are pending prefill requests under chunked scheduling.
"""
infer_mode = None
if is_start_loc_match:
infer_mode = MLUInferMode.PREFILL_ONLY
elif max_query_len <= (1 + num_speculative_tokens) and (not has_prefill_reqs):
infer_mode = MLUInferMode.DECODE_ONLY
else:
infer_mode = MLUInferMode.CHUNKED
num_input_tokens = (
num_actual_tokens if num_input_tokens == 0
else num_input_tokens
)
max_seq_len = int(seq_lens_cpu.max())
return cls(query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
seq_start_loc=seq_start_loc,
seq_start_loc_cpu=seq_start_loc.to("cpu", non_blocking=True),
num_input_tokens=num_input_tokens,
infer_mode=infer_mode,
num_prefill_query_tokens=num_actual_tokens,
num_prefill_kv_tokens=num_actual_tokens)
def save(self, infer_phase: str):
csv_path = os.getenv("VLLM_STEP_INPUT_CSV_PATH", None)
if not csv_path:
return
header = [
"infer_phase", "infer_mode", "num_reqs", "num_actual_tokens",
"max_query_len", "max_seq_len", "query_start_loc", "seq_lens"
]
data = [
infer_phase, self.infer_mode, self.num_reqs,
self.num_actual_tokens, self.max_query_len, self.max_seq_len,
str(self.query_start_loc_cpu.tolist()),
str(self.seq_lens_cpu.tolist())
]
data_dict = dict(zip(header, data))
df_csv = pd.DataFrame(data_dict, index=[0])
if infer_phase == "RealInfer":
print(df_csv.to_string())
try:
if dir_path := os.path.dirname(csv_path):
os.makedirs(dir_path, exist_ok=True)
append = False
if os.path.isfile(csv_path):
try:
df_old = pd.read_csv(csv_path)
append = (df_old.columns.tolist() == header)
except Exception:
# The existing file could not be read; fall back to overwriting it.
append = False
if append:
df_csv.to_csv(csv_path, mode='a', header=False, index=False)
else:
df_csv.to_csv(csv_path, index=False)
except Exception as e:
raise RuntimeError(f"Invalid VLLM_STEP_INPUT_CSV_PATH: {csv_path} to dump step inputs, Error: {e}")
def get_common_metadata_from_attn_metadata(
attn_metadata) -> Union[MLUCommonAttentionMetadata, None]:
"""
Get MLUCommonAttentionMetadata for MLU-V1 inference.
Use outside of set_forward_context().
"""
if attn_metadata is None:
return
assert (isinstance(attn_metadata, dict)
and COMMON_METADATA_STR in attn_metadata), \
f"MLU-V1 only support type(attn_metadata)=dict, and " + \
f"{COMMON_METADATA_STR} in attn_metadata. Now, type(attn_metadata)=" + \
f"{type(attn_metadata)}, or {COMMON_METADATA_STR} not in attn_metadata."
return attn_metadata[COMMON_METADATA_STR]
def get_common_metadata() -> Union[MLUCommonAttentionMetadata, None]:
"""
Get MLUCommonAttentionMetadata for MLU-V1 inference.
Use inside of set_forward_context().
"""
attn_metadata = get_forward_context().attn_metadata
return get_common_metadata_from_attn_metadata(attn_metadata)
def unpad_common_attn_metadata(
common_metadata: MLUCommonAttentionMetadata,
num_reqs: int,
num_scheduled_tokens: int,
):
"""
Unpad MLUCommonAttentionMetadata by given num_reqs and num_scheduled_tokens.
"""
common_metadata.num_reqs = num_reqs
common_metadata.num_input_tokens = num_scheduled_tokens
common_metadata.query_start_loc = common_metadata.query_start_loc[:num_reqs + 1]
common_metadata.query_start_loc_cpu = common_metadata.query_start_loc_cpu[:num_reqs + 1]
common_metadata.seq_start_loc = common_metadata.seq_start_loc[:num_reqs + 1]
common_metadata.seq_lens = common_metadata.seq_lens[:num_reqs]
common_metadata.seq_lens_cpu = common_metadata.seq_lens_cpu[:num_reqs]
common_metadata.block_table_tensor = common_metadata.block_table_tensor[:num_reqs]
def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
decode_threshold: int = 1,
) -> bool:
"""
Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch.
Returns:
True if the batch was modified, False otherwise.
"""
# We now want to reorder the batch into decode → extend → prefill order
# where:
# decode: request with num_scheduled_tokens <= decode_threshold
# extend: non-decode request with existing context
# prefill: non-decode request with no existing context
# NOTE for now we loosely use "decode" to mean requests where attention is
# likely memory-bound and "prefill" to mean requests where attention is
# likely compute-bound.
num_reqs = len(input_batch.req_ids)
num_scheduled_tokens = [
scheduler_output.num_scheduled_tokens[id] for id in input_batch.req_ids
]
num_scheduled_tokens_np = np.array(num_scheduled_tokens)
num_computed_tokens_np = input_batch.num_computed_tokens_cpu[:num_reqs]
'''
=============================
Modify by vllm_mlu
=============================
@brief: tighten the decode-mode condition: a request counts as decode only
if all of its prompt tokens are computed.
'''
# is_decode = num_scheduled_tokens_np <= decode_threshold
is_decode = (
(num_scheduled_tokens_np <= decode_threshold)
& (num_computed_tokens_np >= input_batch.num_prompt_tokens[:num_reqs])
)
'''
==================
End of MLU Hijack
==================
'''
is_extend = (~is_decode) & (num_computed_tokens_np > 0)
is_prefill = (~is_decode) & (num_computed_tokens_np == 0)
# Desired order: decode → extend → prefill
req_regions = np.zeros(is_decode.shape, dtype=np.int32) # 0 = decode by default
req_regions[is_extend] = 1
req_regions[is_prefill] = 2
num_decodes = int(is_decode.sum())
num_extends = int(is_extend.sum())
target_regions = np.zeros(num_reqs, dtype=np.int32)
target_regions[num_decodes : num_decodes + num_extends] = 1
target_regions[num_decodes + num_extends :] = 2
needs_swap = req_regions != target_regions
if not needs_swap.any():
return False
# Extract indices that need swapping and sort by target region
orig_indices = np.where(needs_swap)[0]
sorted_order = np.argsort(req_regions[needs_swap], kind="stable")
src_indices = orig_indices[sorted_order]
src_dest_map = {int(src): int(dst) for src, dst in zip(src_indices, orig_indices)}
for src in src_dest_map:
dst = src_dest_map[src]
while src != dst:
input_batch.swap_states(src, dst)
# Mark dst as done by updating its destination to itself
next_dst = src_dest_map.get(dst, dst)
src_dest_map[dst] = dst
dst = next_dst
return True
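# A minimal sketch (illustrative only) of the region bookkeeping above, with a
# plain list standing in for InputBatch.swap_states. Regions: 0=decode,
# 1=extend, 2=prefill; the target layout is decode -> extend -> prefill.
def _reorder_regions_demo():
    # prefill, decode, extend, decode -> want decode, decode, extend, prefill
    reqs = ["p0", "d0", "e0", "d1"]
    req_regions = np.array([2, 0, 1, 0], dtype=np.int32)
    num_decodes = int((req_regions == 0).sum())   # 2
    num_extends = int((req_regions == 1).sum())   # 1
    target_regions = np.zeros(len(reqs), dtype=np.int32)
    target_regions[num_decodes:num_decodes + num_extends] = 1
    target_regions[num_decodes + num_extends:] = 2          # [0, 0, 1, 2]
    needs_swap = req_regions != target_regions
    assert needs_swap.tolist() == [True, False, False, True]
    # Only positions 0 and 3 are out of place; a single swap fixes the batch.
    reqs[0], reqs[3] = reqs[3], reqs[0]
    assert reqs == ["d1", "d0", "e0", "p0"]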


@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project


@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, overload
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
logger = init_logger(__name__)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class KVCacheManager_MluHijack(KVCacheManager):
def allocate_slots(
self,
request: Request,
num_new_tokens: int,
num_new_computed_tokens: int = 0,
new_computed_blocks: KVCacheBlocks | None = None,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
num_encoder_tokens: int = 0,
fixed_window_tokens: int = 0,
) -> KVCacheBlocks | None:
"""Add slots for a request with new tokens to append.
Args:
request: The request to allocate slots.
num_new_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
num_new_computed_tokens: The number of new computed tokens just
hitting the prefix caching, excluding external tokens.
new_computed_blocks: The cached blocks for the above new computed
tokens.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
which will complete in a future step.
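num_encoder_tokens: The number of encoder tokens to allocate slots for.
fixed_window_tokens: Extra tokens reserved in the slot budget on top of
new and lookahead tokens (MLU-specific parameter).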
Blocks layout:
```
-----------------------------------------------------------------------
| < computed > | < new computed > | < new > | < pre-allocated > |
-----------------------------------------------------------------------
| < required > |
--------------------------------------------------
| < full > |
------------------------------------------------
| <new full> |
--------------
```
The following *_blocks are illustrated in this layout.
Returns:
A list of new allocated blocks.
"""
if num_new_tokens == 0:
raise ValueError("num_new_tokens must be greater than 0")
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = self.empty_kv_cache_blocks.blocks
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
self.coordinator.remove_skipped_blocks(
request.request_id, request.num_computed_tokens
)
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens + fixed_window_tokens,
self.max_model_len,
)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
num_encoder_tokens=num_encoder_tokens,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks
return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self.block_pool.touch(new_computed_block_list)
else:
assert not any(new_computed_block_list), (
"Computed blocks should be empty when prefix caching is disabled"
)
if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.coordinator.save_new_computed_blocks(
request.request_id, new_computed_block_list
)
new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot, num_encoder_tokens
)
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
if not self.enable_caching or delay_cache_blocks:
return self.create_kv_cache_blocks(new_blocks)
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
# num_new_tokens, but must exclude "non-committable" tokens (e.g.,
# draft tokens that could be rejected). Therefore, we cap the number
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
num_tokens_to_cache = min(
num_computed_tokens + num_new_tokens, request.num_tokens
)
self.coordinator.cache_blocks(request, num_tokens_to_cache)
return self.create_kv_cache_blocks(new_blocks)
MluHijackObject.apply_hijack(KVCacheManager,
KVCacheManager.allocate_slots,
KVCacheManager_MluHijack.allocate_slots)
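# Illustrative sketch (not part of the hijack above; all numbers assumed):
# the caching cap excludes draft tokens that could still be rejected. With
# num_computed_tokens=96, num_new_tokens=8, and request.num_tokens=100
# (i.e. 4 of the 8 new tokens are speculative drafts):
example_tokens_to_cache = min(96 + 8, 100)  # -> 100; the 4 drafts are not cached
assert example_tokens_to_cache == 100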

View File

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import (
KVCacheConfig,
KVCacheGroupSpec,
KVCacheSpec,
KVCacheTensor,
UniformTypeKVCacheSpecs,
)
from vllm.v1.core import kv_cache_utils
from vllm.v1.core.kv_cache_utils import (may_override_num_blocks,
get_uniform_page_size,
get_num_blocks)
logger = init_logger(__name__)
def vllm__v1__core__kv_cache_utils__get_kv_cache_config_from_groups(
vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
kv_cache_specs: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
"""
Generate the KV cache configuration from the KV cache groups and spec
of each layer.
Args:
vllm_config: The global VllmConfig
kv_cache_groups: The KV cache groups
kv_cache_specs: The KV cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes
Returns:
The generated KVCacheConfig
"""
if len(kv_cache_groups) == 0:
# Attention free models do not have KV cache.
# Return num_blocks=1 as BlockPool always needs a null_block.
return KVCacheConfig(
num_blocks=1,
kv_cache_tensors=[],
kv_cache_groups=kv_cache_groups,
)
# Determine how model runners should initialize the KV cache tensors.
if len(kv_cache_groups) == 1 and isinstance(
kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs
):
# Special case: all layers have the same type of KV cache but with
# different hidden size. Allocate different amount of memory for each
# layer based on its hidden size.
num_blocks = (
available_memory // kv_cache_groups[0].kv_cache_spec.page_size_bytes
)
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs
kv_cache_tensors = [
KVCacheTensor(
size=per_layer_specs[layer_name].page_size_bytes * num_blocks,
shared_by=[layer_name],
)
for layer_name in kv_cache_groups[0].layer_names
]
else:
# General case:
# We will have group_size memory pools, each is shared by one layer from
# each group. As layers of different groups have different block table,
# they will use different parts of the shared Tensor.
# The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
# (sw.1, padding) will be: (group_size = 2)
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(kv_cache_specs)
'''
=============================
Modify by vllm_mlu
=============================
@brief: support qwen3-next
'''
if vllm_config.mlu_config.enable_mamba_split_page_size:
# Note(wulingchao): reserve memory for linear attention so it does not
# take part in the system's block scheduling. The current page size is
# the small page; it must be scaled up to a full linear-attention page.
mamba_page_size = (page_size
* vllm_config.mlu_config.mamba_to_attn_block_ratio
* vllm_config.mlu_config.mamba_support_max_batch_size
* group_size * 3)
logger.warning(f"all available memory {available_memory}, mamba mem used {mamba_page_size}")
available_memory = available_memory - mamba_page_size
'''
==================
End of MLU Hijack
==================
'''
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(
vllm_config, group_size, available_memory, page_size
)
kv_cache_tensors = []
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(kv_cache_groups[j].layer_names):
shared_by.append(kv_cache_groups[j].layer_names[i])
kv_cache_tensors.append(
KVCacheTensor(size=page_size * num_blocks, shared_by=shared_by)
)
return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=kv_cache_tensors,
kv_cache_groups=kv_cache_groups,
)
MluHijackObject.apply_hijack(kv_cache_utils,
kv_cache_utils.get_kv_cache_config_from_groups,
vllm__v1__core__kv_cache_utils__get_kv_cache_config_from_groups)
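# Illustrative sketch of the shared-tensor mapping above (layer names are
# hypothetical): layer i of every group shares memory pool i.
groups = [["full.0", "full.1"], ["sw.0", "sw.2"], ["sw.1"]]
example_group_size = max(len(g) for g in groups)  # 2
for i in range(example_group_size):
    shared_by = [g[i] for g in groups if i < len(g)]
    print(i, shared_by)
# 0 ['full.0', 'sw.0', 'sw.1']
# 1 ['full.1', 'sw.2']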

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus
from vllm_mlu.v1.core.sched.scheduler import MLUUnchunkScheduler, SchedulerWithProfiler
logger = init_logger(__name__)
class AsyncScheduler(SchedulerWithProfiler):
def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
if (
request.num_computed_tokens
== request.num_tokens
+ request.num_output_placeholders
+ cur_num_spec_tokens
):
# The request will generate a new token plus num_spec_tokens
# in this scheduling step.
request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids.
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
def _update_request_with_output(
self,
request: Request,
new_token_ids: list[int],
) -> tuple[list[int], bool]:
status_before_update = request.status
new_token_ids, stopped = super()._update_request_with_output(
request, new_token_ids
)
# Update the number of output placeholders.
request.num_output_placeholders -= len(new_token_ids)
assert request.num_output_placeholders >= 0
# Cache the new tokens. Preempted requests should be skipped.
if status_before_update == RequestStatus.RUNNING:
self.kv_cache_manager.cache_blocks(
request, request.num_computed_tokens - request.num_output_placeholders
)
return new_token_ids, stopped
class MLUUnchunkAsyncScheduler(MLUUnchunkScheduler):
def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output)
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, []))
if (
request.num_computed_tokens
== request.num_tokens
+ request.num_output_placeholders
+ cur_num_spec_tokens
):
# The request will generate a new token plus num_spec_tokens
# in this scheduling step.
request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids, because the
# actual token ids are not known yet: use -1 as a placeholder and set
# the length of spec_token_ids to self.num_spec_tokens. The actual spec
# token ids are filled in by the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens
def _update_request_with_output(
self,
request: Request,
new_token_ids: list[int],
) -> tuple[list[int], bool]:
status_before_update = request.status
new_token_ids, stopped = super()._update_request_with_output(
request, new_token_ids)
# num_output_placeholders == 0 happens when a request is preempted:
# a preempted request is added back to the waiting queue and its
# num_output_placeholders is reset to 0, so there is no need to
# revert num_output_placeholders in that case.
if request.num_output_placeholders > 0:
# Update the number of output placeholders.
request.num_output_placeholders -= len(new_token_ids)
assert request.num_output_placeholders >= 0
# Cache the new tokens. Preempted requests should be skipped.
if status_before_update == RequestStatus.RUNNING:
self.kv_cache_manager.cache_blocks(
request,
request.num_computed_tokens - request.num_output_placeholders)
return new_token_ids, stopped
def _update_computed_tokens_after_speculation(
self, request: Request, num_rejected: int
):
"""Update the computed tokens for each request, which is necessary
for spec decoding. The sync scheduler only needs to revert
num_computed_tokens by num_rejected tokens; the async scheduler must
additionally revert num_output_placeholders by num_rejected tokens.
"""
# num_computed_tokens == 0 happens when a request is preempted:
# a preempted request is added back to the waiting queue and its
# num_computed_tokens is reset to 0, so there is no need to
# revert num_computed_tokens in that case.
if request.num_computed_tokens > 0:
# when spec decoding is enabled, num_output_placeholders
# is increased by num_spec_tokens in _update_after_schedule.
# update num_output_placeholders here to reflect the actual number
# of accepted output tokens.
request.num_output_placeholders -= num_rejected
super()._update_computed_tokens_after_speculation(request, num_rejected)
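# Worked example of the placeholder accounting above (assumed numbers,
# num_spec_tokens = 2): scheduling credits one new token plus the spec
# tokens; a later rejection reverts by the rejected count, so the counter
# again covers only accepted tokens.
num_output_placeholders = 0
num_output_placeholders += 1 + 2  # _update_after_schedule -> 3
num_rejected = 1
num_output_placeholders -= num_rejected  # after speculation -> 2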

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING
from typing_extensions import deprecated
from vllm._bc_linter import bc_linter_include
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import torch
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
else:
ECConnectorMetadata = object
KVConnectorMetadata = object
LoRARequest = object
MultiModalFeatureSpec = object
PoolingParams = object
SamplingParams = object
Request = object
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add new_token_ids to pass the first token generated
by the prefiller to the decoder's model_runner.
'''
@bc_linter_include
@dataclass
class NewRequestData:
req_id: str
prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: LoRARequest | None
new_token_ids: list[list[int]]
prompt_embeds: "torch.Tensor | None" = None
@classmethod
def from_request(
cls,
request: Request,
block_ids: tuple[list[int], ...],
) -> "NewRequestData":
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
prompt_embeds=request.prompt_embeds,
new_token_ids=request._output_token_ids,
)
def __repr__(self) -> str:
prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds is not None else None
return (
f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape},"
f"new_token_ids={self.new_token_ids}"
")"
)
# Version of __repr__ with the prompt data obfuscated
def anon_repr(self) -> str:
prompt_token_ids_len = (
len(self.prompt_token_ids) if self.prompt_token_ids is not None else None
)
prompt_embeds_shape = self.prompt_embeds.shape if self.prompt_embeds is not None else None
return (
f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={prompt_token_ids_len},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")"
)
'''
==================
End of MLU Hijack
==================
'''
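# Hypothetical P/D usage sketch: on the decode instance, the first token(s)
# sampled by the prefiller arrive in NewRequestData.new_token_ids, so the
# decoder's model runner can resume generation without re-sampling them:
# data = NewRequestData.from_request(request, block_ids)
# first_tokens = data.new_token_ids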

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager,
SlidingWindowManager,
spec_manager_map,
)
from vllm_mlu.v1.kv_cache_interface import (
MLUFullAttentionSpec,
MLUMLAAttentionSpec,
MLUSlidingWindowSpec,
)
spec_manager_map.update({
MLUFullAttentionSpec: FullAttentionManager,
MLUSlidingWindowSpec: SlidingWindowManager,
MLUMLAAttentionSpec: FullAttentionManager,
})

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.engine.async_llm import AsyncLLM
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class AsyncLLM_MluHijack(AsyncLLM):
async def start_scheduler_profile(self) -> None:
await self.engine_core.start_scheduler_profile()
async def stop_scheduler_profile(self) -> None:
await self.engine_core.stop_scheduler_profile()
MluHijackObject.apply_hijack(AsyncLLM,
"start_scheduler_profile",
AsyncLLM_MluHijack.start_scheduler_profile)
MluHijackObject.apply_hijack(AsyncLLM,
"stop_scheduler_profile",
AsyncLLM_MluHijack.stop_scheduler_profile)

566
vllm_mlu/v1/engine/core.py Normal file
View File

@@ -0,0 +1,566 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from collections import deque
import signal
from typing import Any, Callable, cast
from concurrent.futures import Future
from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import engine_receiver_cache_from_config
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.core.kv_cache_utils import BlockHash, get_request_block_hasher, init_none_hash
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.engine.core import (
EngineCore,
EngineCoreProc,
DPEngineCoreProc,
)
from vllm.v1.executor.abstract import Executor
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION
from logging import DEBUG
import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.mlu_metric import LLMMetric
class EngineCore_MluHijack(EngineCore):
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
executor_fail_callback: Callable | None = None,
):
'''
=============================
Modify by vllm_mlu
=============================
@brief: load_general_plugins in run_engine_core
'''
# # plugins need to be loaded at the engine/scheduler level too
# from vllm.plugins import load_general_plugins
# load_general_plugins()
'''
==================
End of MLU Hijack
==================
'''
self.vllm_config = vllm_config
if vllm_config.parallel_config.data_parallel_rank == 0:
logger.info(
"Initializing a V1 LLM engine (v%s) with config: %s",
VLLM_VERSION,
vllm_config,
)
self.log_stats = log_stats
# Setup Model.
self.model_executor = executor_class(vllm_config)
if executor_fail_callback is not None:
self.model_executor.register_failure_callback(executor_fail_callback)
self.available_gpu_memory_for_kv_cache = -1
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config
)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
if len(kv_cache_config.kv_cache_groups) == 0:
# Encoder models without KV cache don't support
# chunked prefill. But do SSM models?
logger.info("Disabling chunked prefill for model without KVCache")
vllm_config.scheduler_config.enable_chunked_prefill = False
scheduler_block_size = (
vllm_config.cache_config.block_size
* vllm_config.parallel_config.decode_context_parallel_size
)
self.scheduler: SchedulerInterface = Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size > 1,
log_stats=self.log_stats,
block_size=scheduler_block_size,
)
self.use_spec_decode = vllm_config.speculative_config is not None
if self.scheduler.connector is not None: # type: ignore
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = engine_receiver_cache_from_config(
vllm_config, mm_registry
)
# If a KV connector is initialized for scheduler, we want to collect
# handshake metadata from all workers so the connector in the scheduler
# will have the full context
kv_connector = self.scheduler.get_kv_connector()
if kv_connector is not None:
# Collect and store KV connector xfer metadata from workers
# (after KV cache registration)
xfer_handshake_metadata = (
self.model_executor.get_kv_connector_handshake_metadata()
)
if xfer_handshake_metadata:
# xfer_handshake_metadata is list of dicts from workers
# Each dict already has structure {tp_rank: metadata}
# Merge all worker dicts into a single dict
content: dict[int, Any] = {}
for worker_dict in xfer_handshake_metadata:
if worker_dict is not None:
content.update(worker_dict)
kv_connector.set_xfer_handshake_metadata(content)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
# schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: (
deque[tuple[Future[ModelRunnerOutput], SchedulerOutput]] | None
) = None
if self.batch_queue_size > 1:
logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
self.batch_queue = deque(maxlen=self.batch_queue_size)
self.ec_producer = (
vllm_config.ec_transfer_config is not None
and vllm_config.ec_transfer_config.is_ec_producer
)
self.is_pooling_model = vllm_config.model_config.runner_type == "pooling"
self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo
)
init_none_hash(caching_hash_fn)
self.request_block_hasher = get_request_block_hasher(
scheduler_block_size, caching_hash_fn
)
self.step_fn = (
self.step if self.batch_queue is None else self.step_with_batch_queue
)
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
freeze_gc_heap()
'''
=============================
Modify by vllm_mlu
=============================
@brief: v1 support offline benchmark
'''
self.step_latency = []
self.model_exec_latency = []
self.mm_encoder_latency = []
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
'''
==================
End of MLU Hijack
==================
'''
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.
Returns tuple of outputs and a flag indicating whether the model
was executed.
"""
'''
=============================
Modify by vllm_mlu
=============================
@brief: v1 support offline benchmark
'''
if mlu_envs.VLLM_LATENCY_DEBUG_EN:
step_start = LLMMetric.get_mlu_cost_time()
'''
==================
End of MLU Hijack
==================
'''
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
if self.use_spec_decode and \
self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.kv_role == "kv_producer":
draft_token_ids = self.model_executor.take_draft_token_ids()
self.scheduler.draft_token_ids = draft_token_ids
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: v1 support offline benchmark
'''
has_sched_reqs = (scheduler_output.total_num_scheduled_tokens > 0)
if mlu_envs.VLLM_LATENCY_DEBUG_EN and has_sched_reqs:
self.step_latency.append(LLMMetric.get_mlu_cost_time() - step_start)
if mlu_envs.VLLM_LATENCY_DEBUG_WITH_DEVICE_EN and has_sched_reqs:
self.model_exec_latency.append(self.get_model_exec_latency())
mm_encoder_latency = self.get_mm_encoder_latency()
if mm_encoder_latency:
self.mm_encoder_latency.append(mm_encoder_latency)
'''
==================
End of MLU Hijack
==================
'''
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
def step_with_batch_queue(
self,
) -> tuple[dict[int, EngineCoreOutputs] | None, bool]:
"""Schedule and execute batches with the batch queue.
Note that if nothing to output in this step, None is returned.
The execution flow is as follows:
1. Try to schedule a new batch if the batch queue is not full.
If a new batch is scheduled, directly return an empty engine core
output. In other words, fulfilling the batch queue has a higher priority
than getting model outputs.
2. If there is no new scheduled batch, meaning that the batch queue
is full or no other requests can be scheduled, we block until the first
batch in the job queue is finished.
3. Update the scheduler from the output.
"""
batch_queue = self.batch_queue
assert batch_queue is not None
# Try to schedule a new batch if the batch queue is not full, but
# the scheduler may return an empty batch if all requests are scheduled.
# Note that this is not blocking.
assert len(batch_queue) < self.batch_queue_size
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
if not self.ec_producer:
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if self.is_pooling_model or not model_executed:
# No sampling required (no requests scheduled).
future = cast(Future[ModelRunnerOutput], exec_future)
else:
exec_future.add_done_callback(self._log_err_callback(scheduler_output))
if not scheduler_output.pending_structured_output_tokens:
# We aren't waiting for any tokens, get any grammar output
# and sample immediately.
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
if not deferred_scheduler_output:
# Add this step's future to the queue.
batch_queue.appendleft((future, scheduler_output))
if (
model_executed
and len(batch_queue) < self.batch_queue_size
and not batch_queue[-1][0].done()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
elif not batch_queue:
# Queue is empty. We should not reach here since this method should
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
'''
=============================
Modify by vllm_mlu
=============================
@brief: support disagg for mlu.
'''
if self.use_spec_decode and \
self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.kv_role == "kv_producer":
draft_token_ids = self.model_executor.take_draft_token_ids()
self.scheduler.draft_token_ids = draft_token_ids
'''
==================
End of MLU Hijack
==================
'''
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed
def get_model_exec_latency(self):
latency = self.model_executor.get_latency()
return latency
def get_mm_encoder_latency(self):
return self.model_executor.get_mm_encoder_latency()
def get_hfu_info(self, batch, input_len, output_len):
return self.model_executor.get_hfu_info(batch, input_len, output_len)
def get_latency(self):
return (self.step_latency, self.model_exec_latency, self.mm_encoder_latency)
def get_memory_usage(self):
peak_memory, block_memory = self.model_executor.get_memory_usage()
return (peak_memory, block_memory,
self.num_gpu_blocks, self.num_cpu_blocks)
def recapture_model(self,
prefill_enable_mlugraph: bool,
batch_size: int,
input_len: int):
self.model_executor.recapture_model(
prefill_enable_mlugraph, batch_size, input_len)
def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
self.step_latency = []
self.model_exec_latency = []
self.mm_encoder_latency = []
mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED = use_unchunk_sched
mlu_envs.VLLM_V1_MIN_PREFILL_BATCH = min_prefill_batch
def start_scheduler_profile(self):
self.scheduler.start_scheduler_profile()
def stop_scheduler_profile(self):
self.scheduler.stop_scheduler_profile()
def response_remote_alloc_once(self):
self.model_executor.response_remote_alloc_once()
class EngineCoreProc_MluHijack(EngineCoreProc):
@staticmethod
def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, **kwargs):
"""Launch EngineCore busy loop in background process."""
'''
=============================
Modify by vllm_mlu
=============================
@brief: load_general_plugins for mp backend engine
'''
# plugins need to be loaded at the engine/scheduler level too
from vllm.plugins import load_general_plugins
load_general_plugins()
'''
==================
End of MLU Hijack
==================
'''
# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
shutdown_requested = False
# Ensure we can serialize transformer config after spawning
maybe_register_config_serialize_by_value()
def signal_handler(signum, frame):
nonlocal shutdown_requested
if not shutdown_requested:
shutdown_requested = True
raise SystemExit()
# Either SIGTERM or SIGINT will terminate the engine_core
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
engine_core: EngineCoreProc | None = None
try:
parallel_config: ParallelConfig = kwargs["vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
set_process_title("EngineCore", f"DP{dp_rank}")
decorate_logs()
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
engine_core = DPEngineCoreProc(*args, **kwargs)
else:
set_process_title("EngineCore")
decorate_logs()
engine_core = EngineCoreProc(*args, **kwargs)
engine_core.run_busy_loop()
except SystemExit:
logger.debug("EngineCore exiting.")
raise
except Exception as e:
if engine_core is None:
logger.exception("EngineCore failed to start.")
else:
logger.exception("EngineCore encountered a fatal error.")
engine_core._send_engine_dead()
raise e
finally:
if engine_core is not None:
engine_core.shutdown()
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
waited = False
while (
not self.engines_running
and not self.scheduler.has_requests()
and not self.batch_queue
):
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.kv_role == "kv_consumer":
self.response_remote_alloc_once()
if self.input_queue.empty():
continue
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
else:
req = self.input_queue.get()
self._handle_client_request(*req)
if waited:
logger.debug("EngineCore loop active.")
if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.kv_role == "kv_consumer":
self.response_remote_alloc_once()
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
MluHijackObject.apply_hijack(EngineCore,
"get_mm_encoder_latency",
EngineCore_MluHijack.get_mm_encoder_latency)
MluHijackObject.apply_hijack(EngineCore,
"get_model_exec_latency",
EngineCore_MluHijack.get_model_exec_latency)
MluHijackObject.apply_hijack(EngineCore,
"get_hfu_info",
EngineCore_MluHijack.get_hfu_info)
MluHijackObject.apply_hijack(EngineCore,
"get_latency",
EngineCore_MluHijack.get_latency)
MluHijackObject.apply_hijack(EngineCore,
"get_memory_usage",
EngineCore_MluHijack.get_memory_usage)
MluHijackObject.apply_hijack(EngineCore,
"recapture_model",
EngineCore_MluHijack.recapture_model)
MluHijackObject.apply_hijack(EngineCore,
"init_metric",
EngineCore_MluHijack.init_metric)
MluHijackObject.apply_hijack(EngineCore,
"start_scheduler_profile",
EngineCore_MluHijack.start_scheduler_profile)
MluHijackObject.apply_hijack(EngineCore,
"stop_scheduler_profile",
EngineCore_MluHijack.stop_scheduler_profile)
MluHijackObject.apply_hijack(EngineCore,
EngineCore.__init__,
EngineCore_MluHijack.__init__)
MluHijackObject.apply_hijack(EngineCore,
EngineCore.step,
EngineCore_MluHijack.step)
MluHijackObject.apply_hijack(EngineCore,
"response_remote_alloc_once",
EngineCore_MluHijack.response_remote_alloc_once)
MluHijackObject.apply_hijack(EngineCore,
EngineCore.step_with_batch_queue,
EngineCore_MluHijack.step_with_batch_queue)
MluHijackObject.apply_hijack(EngineCoreProc,
EngineCoreProc.run_engine_core,
EngineCoreProc_MluHijack.run_engine_core)
MluHijackObject.apply_hijack(EngineCoreProc,
EngineCoreProc._process_input_queue,
EngineCoreProc_MluHijack._process_input_queue)
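# Hypothetical usage of the offline-benchmark hooks hijacked above; assumes
# VLLM_LATENCY_DEBUG_EN is set before engine start and `engine_core` is a
# constructed EngineCore:
# engine_core.init_metric(use_unchunk_sched=True, min_prefill_batch=1)
# for _ in range(32):
#     engine_core.step()
# step_lat, exec_lat, enc_lat = engine_core.get_latency()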

View File

@@ -0,0 +1,227 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.engine.core_client import (
EngineCoreClient,
InprocClient,
MPClient,
SyncMPClient,
AsyncMPClient,
DPAsyncMPClient,
DPLBAsyncMPClient,
)
from vllm.v1.engine import EngineCoreRequest
from vllm.config import VllmConfig
from vllm.v1.executor import Executor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class EngineCoreClient_MluHijack(EngineCoreClient):
@staticmethod
def make_async_mp_client(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
) -> "MPClient":
parallel_config = vllm_config.parallel_config
client_args = (
vllm_config,
executor_class,
log_stats,
client_addresses,
client_count,
client_index,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: disagg use DPAsyncMPClient instead of DPLBAsyncMPClient.
'''
if parallel_config.data_parallel_size > 1:
if parallel_config.data_parallel_external_lb or vllm_config.kv_transfer_config is not None:
# External load balancer - client per DP rank.
return DPAsyncMPClient(*client_args)
# Internal load balancer - client balances to all DP ranks.
return DPLBAsyncMPClient(*client_args)
'''
==================
End of MLU Hijack
==================
'''
return AsyncMPClient(*client_args)
class InprocClient_MluHijack(InprocClient):
def get_hfu_info(self, batch, input_len, output_len):
return self.engine_core.get_hfu_info(batch, input_len, output_len)
def get_latency(self):
return self.engine_core.get_latency()
def get_memory_usage(self):
return self.engine_core.get_memory_usage()
def recapture_model(
self,
prefill_enable_mlugraph: bool,
batch_size: int,
input_len: int,
):
return self.engine_core.recapture_model(
prefill_enable_mlugraph, batch_size, input_len
)
def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
return self.engine_core.init_metric(
use_unchunk_sched, min_prefill_batch,
)
def start_scheduler_profile(self):
self.engine_core.start_scheduler_profile()
def stop_scheduler_profile(self):
self.engine_core.stop_scheduler_profile()
def response_remote_alloc_once(self) -> None:
self.engine_core.response_remote_alloc_once()
class SyncMPClient_MluHijack(SyncMPClient):
def get_hfu_info(self, batch, input_len, output_len):
try:
return self.call_utility("get_hfu_info", batch, input_len, output_len)
except Exception as e:
raise RuntimeError(f"Failed to get HFU info: {str(e)}")
def get_latency(self):
return self.call_utility("get_latency")
def get_memory_usage(self):
return self.call_utility("get_memory_usage")
def recapture_model(self,
prefill_enable_mlugraph: bool,
batch_size: int,
input_len: int):
return self.call_utility("recapture_model",
prefill_enable_mlugraph, batch_size, input_len)
def init_metric(self, use_unchunk_sched: bool, min_prefill_batch: int):
return self.call_utility("init_metric",
use_unchunk_sched,
min_prefill_batch)
def start_scheduler_profile(self):
self.call_utility("start_scheduler_profile")
def stop_scheduler_profile(self):
self.call_utility("stop_scheduler_profile")
def response_remote_alloc_once(self) -> None:
self.call_utility("response_remote_alloc_once")
class AsyncMPClient_MluHijack(AsyncMPClient):
async def start_scheduler_profile(self) -> None:
await self.call_utility_async("start_scheduler_profile")
async def stop_scheduler_profile(self) -> None:
await self.call_utility_async("stop_scheduler_profile")
async def response_remote_alloc_once(self) -> None:
await self.call_utility_async("response_remote_alloc_once")
class DPAsyncMPClient_MluHijack(DPAsyncMPClient):
def get_core_engine_for_request(self, request: EngineCoreRequest):
'''
=============================
Modify by vllm_mlu
=============================
@brief: disagg need proxy to assign dp_rank
'''
if request.data_parallel_rank is not None:
# engines are already in rank order
return self.core_engines[request.data_parallel_rank]
'''
==================
End of MLU Hijack
==================
'''
return self.core_engine
MluHijackObject.apply_hijack(EngineCoreClient,
EngineCoreClient.make_async_mp_client,
EngineCoreClient_MluHijack.make_async_mp_client)
MluHijackObject.apply_hijack(InprocClient,
"get_hfu_info",
InprocClient_MluHijack.get_hfu_info)
MluHijackObject.apply_hijack(InprocClient,
"get_latency",
InprocClient_MluHijack.get_latency)
MluHijackObject.apply_hijack(InprocClient,
"get_memory_usage",
InprocClient_MluHijack.get_memory_usage)
MluHijackObject.apply_hijack(InprocClient,
"recapture_model",
InprocClient_MluHijack.recapture_model)
MluHijackObject.apply_hijack(InprocClient,
"init_metric",
InprocClient_MluHijack.init_metric)
MluHijackObject.apply_hijack(InprocClient,
"start_scheduler_profile",
InprocClient_MluHijack.start_scheduler_profile)
MluHijackObject.apply_hijack(InprocClient,
"stop_scheduler_profile",
InprocClient_MluHijack.stop_scheduler_profile)
MluHijackObject.apply_hijack(InprocClient,
"response_remote_alloc_once",
InprocClient_MluHijack.response_remote_alloc_once)
MluHijackObject.apply_hijack(SyncMPClient,
"get_hfu_info",
SyncMPClient_MluHijack.get_hfu_info)
MluHijackObject.apply_hijack(SyncMPClient,
"get_latency",
SyncMPClient_MluHijack.get_latency)
MluHijackObject.apply_hijack(SyncMPClient,
"get_memory_usage",
SyncMPClient_MluHijack.get_memory_usage)
MluHijackObject.apply_hijack(SyncMPClient,
"recapture_model",
SyncMPClient_MluHijack.recapture_model)
MluHijackObject.apply_hijack(SyncMPClient,
"init_metric",
SyncMPClient_MluHijack.init_metric)
MluHijackObject.apply_hijack(SyncMPClient,
"start_scheduler_profile",
SyncMPClient_MluHijack.start_scheduler_profile)
MluHijackObject.apply_hijack(SyncMPClient,
"stop_scheduler_profile",
SyncMPClient_MluHijack.stop_scheduler_profile)
MluHijackObject.apply_hijack(SyncMPClient,
"response_remote_alloc_once",
SyncMPClient_MluHijack.response_remote_alloc_once)
MluHijackObject.apply_hijack(AsyncMPClient,
"start_scheduler_profile",
AsyncMPClient_MluHijack.start_scheduler_profile)
MluHijackObject.apply_hijack(AsyncMPClient,
"stop_scheduler_profile",
AsyncMPClient_MluHijack.stop_scheduler_profile)
MluHijackObject.apply_hijack(AsyncMPClient,
"response_remote_alloc_once",
AsyncMPClient_MluHijack.response_remote_alloc_once)
MluHijackObject.apply_hijack(DPAsyncMPClient,
DPAsyncMPClient.get_core_engine_for_request,
DPAsyncMPClient_MluHijack.get_core_engine_for_request)
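# Hypothetical routing sketch for the disagg change above: when the proxy
# has assigned a data_parallel_rank, the request is pinned to that engine;
# otherwise the client's default core engine handles it.
# request.data_parallel_rank = 1
# engine = client.get_core_engine_for_request(request)  # -> core_engines[1]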

View File

@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.engine.llm_engine import LLMEngine
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__engine__llm_engine__LLMEngine__get_hfu_info(self, batch, input_len, output_len):
return self.engine_core.get_hfu_info(batch, input_len, output_len)
def vllm__engine__llm_engine__LLMEngine__get_latency(self):
return self.engine_core.get_latency()
def vllm__engine__llm_engine__LLMEngine__get_memory_usage(self):
return self.engine_core.get_memory_usage()
def vllm__engine__llm_engine__LLMEngine__start_scheduler_profile(self):
self.engine_core.start_scheduler_profile()
def vllm__engine__llm_engine__LLMEngine__stop_scheduler_profile(self):
self.engine_core.stop_scheduler_profile()
MluHijackObject.apply_hijack(LLMEngine,
"get_hfu_info",
vllm__engine__llm_engine__LLMEngine__get_hfu_info)
MluHijackObject.apply_hijack(LLMEngine,
"get_latency",
vllm__engine__llm_engine__LLMEngine__get_latency)
MluHijackObject.apply_hijack(LLMEngine,
"get_memory_usage",
vllm__engine__llm_engine__LLMEngine__get_memory_usage)
MluHijackObject.apply_hijack(LLMEngine,
"start_scheduler_profile",
vllm__engine__llm_engine__LLMEngine__start_scheduler_profile)
MluHijackObject.apply_hijack(LLMEngine,
"stop_scheduler_profile",
vllm__engine__llm_engine__LLMEngine__stop_scheduler_profile)

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.executor.abstract import Executor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm__v1__executor__abstract__Executor__get_hfu_info(self, batch, input_len, output_len):
output = self.collective_rpc("get_hfu_info", args=([batch, input_len, output_len]))
return max(output)
def vllm__v1__executor__abstract__Executor__get_mm_encoder_latency(self):
output = self.collective_rpc("get_mm_encoder_latency")
return None if any(item is None for item in output) else max(output)
def vllm__v1__executor__abstract__Executor__get_latency(self):
output = self.collective_rpc("get_latency")
return max(output)
def vllm__v1__executor__abstract__Executor__get_memory_usage(self):
output = self.collective_rpc("get_memory_usage")
return output[0]
def vllm__v1__executor__abstract__Executor__recapture_model(
self, prefill_enable_mlugraph: bool, batch_size: int, input_len: int):
self.collective_rpc("recapture_model",
args=(prefill_enable_mlugraph, batch_size, input_len))
MluHijackObject.apply_hijack(
Executor,
"get_hfu_info",
vllm__v1__executor__abstract__Executor__get_hfu_info
)
MluHijackObject.apply_hijack(
Executor,
"get_latency",
vllm__v1__executor__abstract__Executor__get_latency
)
MluHijackObject.apply_hijack(
Executor,
"get_mm_encoder_latency",
vllm__v1__executor__abstract__Executor__get_mm_encoder_latency
)
MluHijackObject.apply_hijack(
Executor,
"get_memory_usage",
vllm__v1__executor__abstract__Executor__get_memory_usage
)
MluHijackObject.apply_hijack(
Executor,
"recapture_model",
vllm__v1__executor__abstract__Executor__recapture_model
)
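# Illustration of the reductions above (values assumed): collective_rpc
# returns one value per worker; latency-style metrics report the slowest
# worker, while memory usage is taken from rank 0 only.
per_worker_latency = [0.91, 1.07, 0.98]
assert max(per_worker_latency) == 1.07  # reported model-exec latency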

View File

@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
from vllm_mlu.mlu_hijack_utils import MluHijackObject
class MultiprocExecutor_MluHijack(MultiprocExecutor):
def response_remote_alloc_once(self) -> None:
self.collective_rpc("response_remote_alloc_once", unique_reply_rank=self.output_rank)
MluHijackObject.apply_hijack(MultiprocExecutor,
"response_remote_alloc_once",
MultiprocExecutor_MluHijack.response_remote_alloc_once)

View File

@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
from collections import defaultdict
from typing import TYPE_CHECKING, Any
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
from vllm.v1.executor.ray_executor import RayDistributedExecutor, RayWorkerMetaData
from vllm.v1.executor.ray_utils import (
RayWorkerWrapper,
initialize_ray_cluster,
ray,
)
from vllm.utils.network_utils import (
get_distributed_init_method,
get_ip,
get_open_port,
)
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.v1.core.sched.output import SchedulerOutput
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
class RayDistributedExecutor_MluHijack(RayDistributedExecutor):
def _init_executor(self) -> None:
self.forward_dag: ray.dag.CompiledDAG | None = None
'''
=============================
Modify by vllm_mlu
=============================
@brief: For MLU, avoid compiling NVIDIA's NCCL
'''
# For TPU, XPU, or out-of-tree platforms such as MLU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu() or current_platform.is_xpu() or \
current_platform.is_out_of_tree():
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
'''
==================
End of MLU Hijack
==================
'''
assert self.uses_ray
initialize_ray_cluster(self.parallel_config)
placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection.
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
# Create the parallel GPU workers.
self._init_workers_ray(placement_group)
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.uses_sampler = self.vllm_config.model_config.runner_type != "pooling" and (
self.vllm_config.ec_transfer_config is None
or not self.vllm_config.ec_transfer_config.is_ec_producer
)
self.scheduler_output: "SchedulerOutput | None" = None
def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
'''
=============================
Modify by vllm_mlu
=============================
@brief: use default cnperf config.
'''
runtime_env.update({
# use default cnperf config
"nsight": "default"
})
'''
==================
End of MLU Hijack
==================
'''
return ray_remote_kwargs
def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs):
num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
# The driver dummy worker does not actually use any resources.
# It holds the resource for the driver worker.
self.driver_dummy_worker: RayWorkerWrapper | None = None
# The remaining workers are the actual ray actors.
self.workers: list[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: list[list[RayWorkerWrapper]] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs
)
# Create the workers.
bundle_indices: list[int]
if envs.VLLM_RAY_BUNDLE_INDICES:
# Use the bundle indices specified by the user.
bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
assert len(bundle_indices) == self.parallel_config.world_size, (
"VLLM_RAY_BUNDLE_INDICES must have the same size"
f" as the world size, but got {bundle_indices=} "
f"and {self.parallel_config.world_size=}"
)
assert len(set(bundle_indices)) == len(bundle_indices), (
"VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
f" but got {bundle_indices=}"
)
else:
# use the first N bundles that have GPU resources.
bundle_indices = []
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if bundle.get(current_platform.ray_device_key, 0):
bundle_indices.append(bundle_id)
bundle_indices = bundle_indices[: self.parallel_config.world_size]
worker_metadata: list[RayWorkerMetaData] = []
driver_ip = get_ip()
for rank, bundle_id in enumerate(bundle_indices):
'''
=============================
Modify by vllm_mlu
=============================
@brief: support ray + cnperf-cli
'''
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs['runtime_env'].update({
"nsight": {
"o": f"cnperf_rank_{rank}",
"force_overwrite": "true"
}
})
if rank == 0:
ray_remote_kwargs['runtime_env'].update({
"nsight": {}
})
'''
==================
End of MLU Hijack
==================
'''
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
if current_platform.ray_device_key == "GPU":
# NV+AMD GPUs, and Intel XPUs
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
vllm_config=self.vllm_config, rpc_rank=rank
)
else:
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
vllm_config=self.vllm_config, rpc_rank=rank
)
worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))
worker_ips = ray.get(
[
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
for each in worker_metadata
]
)
for each, ip in zip(worker_metadata, worker_ips):
each.ip = ip
logger.debug("workers: %s", worker_metadata)
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
ip_counts: dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the worker is on a node with a smaller IP address, it
should be placed first.
"""
ip = item.ip
return 0 if ip == driver_ip else 1, ip_counts[ip], ip
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
sorted_worker_metadata = sorted(
worker_metadata, key=sort_by_driver_then_worker_ip
)
for i, item in enumerate(sorted_worker_metadata):
item.adjusted_rank = i
self.workers = [item.worker for item in sorted_worker_metadata]
rerank_mapping = {
item.created_rank: item.adjusted_rank for item in sorted_worker_metadata
}
self.collective_rpc("adjust_rank", args=(rerank_mapping,))
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids = []
for worker in [self.driver_dummy_worker] + self.workers:
if worker is None:
# driver_dummy_worker can be None when using ray spmd worker.
continue
worker_node_and_gpu_ids.append(
ray.get(worker.get_node_and_gpu_ids.remote())
) # type: ignore[attr-defined]
node_workers = defaultdict(list) # node id -> list of worker ranks
node_gpus = defaultdict(list) # node id -> list of gpu ids
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
node_workers[node_id].append(i)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids = [int(x) for x in gpu_ids]
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
all_ips = set(worker_ips + [driver_ip])
n_ips = len(all_ips)
n_nodes = len(node_workers)
if n_nodes != n_ips:
raise RuntimeError(
f"Every node should have a unique IP address. Got {n_nodes}"
f" nodes with node ids {list(node_workers.keys())} and "
f"{n_ips} unique IP addresses {all_ips}. Please check your"
" network configuration. If you set `VLLM_HOST_IP`"
" environment variable, make sure it is unique for"
" each node."
)
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [
{
current_platform.device_control_env_var: ",".join(
map(str, node_gpus[node_id])
),
}
for (node_id, _) in worker_node_and_gpu_ids
]
# Environment variables to copy from driver to workers
env_vars_to_copy = get_env_vars_to_copy(
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
additional_vars=set(current_platform.additional_env_vars).union(
self.ADDITIONAL_ENV_VARS
),
destination="workers",
)
# Copy existing env vars to each worker's args
for args in all_args_to_update_environment_variables:
# TODO: refactor platform-specific env vars
for name in env_vars_to_copy:
if name in os.environ:
args[name] = os.environ[name]
self._env_vars_for_all_workers = all_args_to_update_environment_variables
self.collective_rpc(
"update_environment_variables", args=(self._get_env_vars_to_be_updated(),)
)
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip = "127.0.0.1"
distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port()
)
# Initialize the actual workers inside worker wrapper.
all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
)
all_kwargs.append(kwargs)
self.collective_rpc("init_worker", args=(all_kwargs,))
self.collective_rpc("init_device")
self.collective_rpc("load_model")
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size):
# PP=2, TP=4
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank = (pp_rank * self.parallel_config.tensor_parallel_size) + tp_rank
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])
MluHijackObject.apply_hijack(
RayDistributedExecutor,
RayDistributedExecutor._configure_ray_workers_use_nsight,
RayDistributedExecutor_MluHijack._configure_ray_workers_use_nsight
)
MluHijackObject.apply_hijack(
RayDistributedExecutor,
RayDistributedExecutor._init_workers_ray,
RayDistributedExecutor_MluHijack._init_workers_ray
)
MluHijackObject.apply_hijack(
RayDistributedExecutor,
RayDistributedExecutor._init_executor,
RayDistributedExecutor_MluHijack._init_executor
)
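# Illustrative example of sort_by_driver_then_worker_ip above (IPs are
# hypothetical): driver-node workers come first, then nodes with fewer
# workers, then smaller IPs.
driver_ip = "10.0.0.1"
ip_counts = {"10.0.0.1": 1, "10.0.0.2": 2}
def sort_key(ip):
    return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
assert sorted(["10.0.0.2", "10.0.0.1", "10.0.0.2"], key=sort_key)[0] == "10.0.0.1"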

View File

@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from dataclasses import dataclass
from typing_extensions import Self
import torch
from math import prod
from vllm.logger import init_logger
from vllm.utils.torch_utils import get_dtype_size
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
MLAAttentionSpec,
SlidingWindowSpec,
MambaSpec,
)
logger = init_logger(__name__)
@dataclass(frozen=True)
class MLUFullAttentionSpec(FullAttentionSpec):
@property
def type_id(self) -> str:
return f"mlu_full_attention_{self.block_size}_{self.page_size_bytes}"
@property
def cache_size_bytes(self) -> int:
return (
2
* self.block_size
* self.num_kv_heads
* self.head_size
* get_dtype_size(self.dtype)
)
@property
def scale_size_bytes(self) -> int:
scale_size_bytes = 0
if self.dtype in [torch.int8, torch.uint8]:
scale_size_bytes = (
2
* self.block_size
* self.num_kv_heads
* get_dtype_size(torch.float32)
)
return scale_size_bytes
@property
def page_size_bytes(self) -> int:
'''
=============================
Modify by vllm_mlu
=============================
@brief: calculate kv_cache_scale size when kv_cache_dtype=int8
'''
return self.cache_size_bytes + self.scale_size_bytes
'''
==================
End of MLU Hijack
==================
'''
@dataclass(frozen=True)
class MLUMLAAttentionSpec(MLAAttentionSpec):
# Used to record k_cache info for the DSA indexer
index_head_dim: int = 0
index_n_heads: int = 0
@property
def type_id(self) -> str:
return f"mlu_mla_attention_{self.block_size}_{self.page_size_bytes}"
@property
def cache_size_bytes(self) -> int:
return (
self.block_size
* self.num_kv_heads
* self.head_size
* get_dtype_size(self.dtype)
)
@property
def scale_size_bytes(self) -> int:
scale_size_bytes = 0
if self.dtype in [torch.int8, torch.uint8]:
scale_size_bytes = (
self.block_size
* self.num_kv_heads
* get_dtype_size(torch.float32)
)
return scale_size_bytes
@property
def index_cache_size_bytes(self) -> int:
return (
self.block_size
* self.index_n_heads
* self.index_head_dim
* get_dtype_size(self.dtype)
)
@property
def page_size_bytes(self) -> int:
'''
=============================
Modify by vllm_mlu
=============================
@brief: calculate kv_cache_scale size when kv_cache_dtype=int8
@brief: calculate indexer cache size for deepseek v3.2
'''
return self.cache_size_bytes + self.scale_size_bytes + self.index_cache_size_bytes
'''
==================
End of MLU Hijack
==================
'''
@classmethod
def merge(cls, specs: list[Self]) -> Self:
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be MLAAttentionSpec."
)
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
assert len(cache_dtype_str_set) == 1, (
"All attention layers in the same KV cache group must use the same "
"quantization method."
)
return cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
cache_dtype_str=cache_dtype_str_set.pop(),
index_head_dim=specs[0].index_head_dim,
index_n_heads=specs[0].index_n_heads,
)
@dataclass(frozen=True)
class MLUSlidingWindowSpec(SlidingWindowSpec):
@property
def type_id(self) -> str:
return f"mlu_sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa
@property
def cache_size_bytes(self) -> int:
return (
2
* self.block_size
* self.num_kv_heads
* self.head_size
* get_dtype_size(self.dtype)
)
@property
def scale_size_bytes(self) -> int:
scale_size_bytes = 0
if self.dtype in [torch.int8, torch.uint8]:
scale_size_bytes = (
2
* self.block_size
* self.num_kv_heads
* get_dtype_size(torch.float32)
)
return scale_size_bytes
@property
def page_size_bytes(self) -> int:
'''
=============================
Modify by vllm_mlu
=============================
@brief: calculate kv_cache_scale size when kv_cache_dtype=int8
'''
return self.cache_size_bytes + self.scale_size_bytes
'''
==================
End of MLU Hijack
==================
'''
@property
def vllm__v1__kv_cache_interface__MambaSpec__page_size_bytes(self) -> int:
page_size = sum(
prod(shape) * get_dtype_size(dtype)
for (shape, dtype) in zip(self.shapes, self.dtypes)
)
if self.page_size_padded is not None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: relax the page_size_padded assertion below to support qwen3-next
'''
# assert self.page_size_padded >= page_size
'''
==================
End of MLU Hijack
==================
'''
return self.page_size_padded
return page_size
MluHijackObject.apply_hijack(MambaSpec,
MambaSpec.page_size_bytes,
vllm__v1__kv_cache_interface__MambaSpec__page_size_bytes)
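# Illustrative sizing (hypothetical shapes and dtypes): with
# shapes = [(4096, 3), (32, 128, 16)] and dtypes = [torch.float16,
# torch.float32], the page is 4096*3*2 + 32*128*16*4 = 24576 + 262144
# = 286720 bytes, unless page_size_padded overrides it via the relaxed
# branch above.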

View File

@@ -0,0 +1,946 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional
import math
import torch
import triton
import triton.language as tl
import vllm
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample import rejection_sampler
from vllm.v1.sample.rejection_sampler import sample_recovered_tokens
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu._mlu_utils import *
from vllm_mlu import _mlu_ops as mlu_ops
PLACEHOLDER_TOKEN_ID: tl.constexpr = -1
GREEDY_TEMPERATURE: tl.constexpr = 0
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 128
'''
=============================
Modify by vllm_mlu
=============================
@brief:
- Limit maximum batch size due to NRAM memory constraints
- Add generate_recovered_uniform_probs function for tmo rejection sampler
'''
MAX_BATCH_SIZE = 65536
def generate_recovered_uniform_probs(
num_tokens: int,
vocab_size: int,
num_draft_tokens: list[int],
sampling_metadata: SamplingMetadata,
device: torch.device,
) -> torch.Tensor:
q = torch.empty(
(num_tokens, vocab_size),
dtype=torch.float32,
device=device,
)
q.exponential_()
for i, generator in sampling_metadata.generators.items():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)
return q
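# Sanity-check sketch (illustrative helper, not used by the hijack): the
# recovered-token kernel relies on the exponential-race trick, i.e. with
# q_i ~ Exp(1) i.i.d., argmax_i(p_i / q_i) selects index i with probability
# p_i / sum(p), which is why `prob` never needs explicit normalization.
def _exponential_race_demo() -> None:
    p = torch.tensor([0.1, 0.6, 0.3])
    q = torch.empty(100000, 3).exponential_()
    samples = torch.argmax(p / q, dim=-1)
    # Empirical frequencies should be close to [0.1, 0.6, 0.3].
    print(torch.bincount(samples, minlength=3) / samples.numel())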
'''
=============================
End of MLU Hijack
=============================
'''
def vllm__v1__sample__rejection_sampler__expand_batch_to_tokens(
x: torch.Tensor, # [batch_size]
cu_num_tokens: torch.Tensor, # [batch_size]
num_tokens: int,
replace_from: int = 0,
replace_to: int = 0,
) -> torch.Tensor:
"""Expand [batch_size] tensor to [num_tokens] tensor based on the number of
tokens per batch in cu_num_tokens.
For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
num_tokens = 6, and expanded_x = [a, a, b, b, b, c].
Args:
x: [batch_size] tensor to expand.
cu_num_tokens: [batch_size] tensor containing the cumulative number of
tokens per batch. Each element represents the total number of
tokens up to and including that batch.
num_tokens: Total number of tokens.
replace_from: Value to be replaced if it is found in x.
replace_to: Value to replace with when replace_from is found.
Returns:
expanded_x: [num_tokens] tensor.
"""
batch_size = x.shape[0]
assert cu_num_tokens.shape[0] == batch_size
'''
=============================
Modify by vllm_mlu
=============================
'''
if batch_size > MAX_BATCH_SIZE:
raise ValueError(f"Rejection Sampler Not Supported: "
f"Batch size exceeds the maximum allowed value of {MAX_BATCH_SIZE}")
'''
==================
End of MLU Hijack
==================
'''
expanded_x = x.new_empty(num_tokens)
vllm__v1__sample__rejection_sampler__expand_kernel[(batch_size, )](
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
)
return expanded_x
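# Reference-only PyTorch equivalent of the Triton expand kernel below,
# handy for unit testing (assumes 1-D tensors on the same device):
def _expand_batch_to_tokens_ref(
    x: torch.Tensor,
    cu_num_tokens: torch.Tensor,
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    # Per-request token counts recovered from the cumulative sums.
    counts = torch.diff(cu_num_tokens, prepend=cu_num_tokens.new_zeros(1))
    replaced = torch.where(x == replace_from, x.new_full((), replace_to), x)
    expanded = torch.repeat_interleave(replaced, counts)
    assert expanded.numel() == num_tokens
    return expanded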
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def vllm__v1__sample__rejection_sampler__expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
MAX_NUM_TOKENS: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0: # noqa: SIM108
'''
=============================
Modify by vllm_mlu
=============================
'''
# Ensure data types are consistent
start_idx = tl.full((), 0, tl.int64)
'''
==================
End of MLU Hijack
==================
'''
else:
'''
=============================
Modify by vllm_mlu
=============================
'''
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1).to(tl.int64)
'''
==================
End of MLU Hijack
==================
'''
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
num_tokens = end_idx - start_idx
src_val = tl.load(input_ptr + req_idx)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)
@triton.jit
def vllm__v1__sample__rejection_sampler__sample_recovered_tokens_kernel(
output_token_ids_ptr, # [num_tokens]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
BLOCK_VOCAB: tl.constexpr = 2048,
):
req_idx = tl.program_id(0)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
# Early exit for out-of-range positions.
pos = tl.program_id(1)
if pos >= num_draft_tokens:
return
'''
=============================
Modify by vllm_mlu
=============================
'''
max_score = -float("inf")
max_index = 0
'''
==================
End of MLU Hijack
==================
'''
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
# Temporarily zero out the probability of the draft token.
# This is essentially the same as target_prob - draft_prob, except that
# n-gram does not have draft_prob. We regard it as 1.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
0)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Replace with a block loop over the vocab due to NRAM size limitations
'''
num_blocks = tl.cdiv(PADDED_VOCAB_SIZE, BLOCK_VOCAB)
for i in tl.range(0, num_blocks):
offset = i * BLOCK_VOCAB + tl.arange(0, BLOCK_VOCAB)
mask = offset < vocab_size
if NO_DRAFT_PROBS:
prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + offset,
mask=mask,
other=0
)
else:
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + offset,
mask=mask,
other=0
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + offset,
mask=mask,
other=0
)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q = tl.load(q_ptr + req_idx * vocab_size + offset,
mask=mask,
other=float("-inf"))
score = prob / q  # Elementwise scores for the exponential race.
# Manually maintain a running argmax across vocab blocks; the whole
# vocabulary cannot be scored in a single block.
cur_score = tl.max(score, axis=0)
cur_arg = tl.argmax(score, axis=0)
cur_index = i * BLOCK_VOCAB + cur_arg
if cur_score > max_score:
    max_score = cur_score
    max_index = cur_index
tl.store(output_token_ids_ptr + start_idx + pos, max_index)
'''
==================
End of MLU Hijack
==================
'''
if NO_DRAFT_PROBS:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)
"""
=============================
Modify by vllm_mlu
=============================
"""
def filter_with_acceptance_rate(output_token_ids, # [batch_size, max_spec_len + 1]
fixed_acceptance_rate):
"""
Filter speculative tokens based on a fixed acceptance rate using batch-level accept/reject decisions.
This function implements an adaptive acceptance rate control mechanism that maintains a target
acceptance rate over time through error compensation and PID-style adjustments.
Args:
output_token_ids (torch.Tensor): Input tensor of shape [batch_size, max_spec_len + 1]
where the first column contains base tokens and remaining columns contain speculative tokens
fixed_acceptance_rate (float or None): Target acceptance rate between 0.0 and 1.0
If None, returns input tensor unchanged
Returns:
torch.Tensor: Modified tensor where rejected batches have all speculative tokens
(columns 1 to max_spec_len) set to PLACEHOLDER_TOKEN_ID
Algorithm Flow:
1. **Initialization Phase**:
- Extract batch dimensions and device information
- Initialize static variables for tracking acceptance statistics:
* cumulative_error: Long-term error accumulation
* total_batches/accepted_batches: Global acceptance tracking
* acceptance_history: Sliding window for recent performance
* precision_adjustment: PID controller adjustment factor
* recent_adjustments: Error history for PID calculation
2. **Statistics Calculation**:
- Calculate global acceptance rate from all historical data
- Calculate sliding window acceptance rate from recent batches
- Compute combined error using weighted average of global and window errors
- Weight transitions from global-focused (early) to window-focused (later)
3. **PID Controller Adjustment** (after 50+ batches):
- Proportional term: Current error magnitude
- Integral term: Accumulated error over recent history
- Derivative term: Rate of error change
- Combines P, I, D terms to compute precision adjustment factor
- Limits adjustment range to prevent over-correction
4. **Error Correction**:
- Applies smooth nonlinear correction based on combined error magnitude
- Uses exponential decay mapping for gradual adjustment strength
- Handles boundary cases (0.0, 1.0, very low rates) specially
5. **Gap-based Adjustment**:
- Calculates difference between target and actual accepted batches
- Applies adaptive threshold-based corrections
- Uses exponential smoothing for adjustment strength
- Adjustment strength decreases as total batch count increases
6. **Random Perturbation** (after 100+ batches):
- Adds small random noise to prevent local minima
- Noise amplitude decreases over time for stability
7. **Batch Decision**:
- Generates random value and compares with adjusted acceptance rate
- Makes binary accept/reject decision for entire batch
8. **Token Modification**:
- If accepted: Leave all tokens unchanged
- If rejected: Set all speculative tokens (columns 1:) to PLACEHOLDER_TOKEN_ID
- This ensures token-level acceptance rate matches batch-level rate
9. **State Updates**:
- Update acceptance counters and history
- Update cumulative error using exponential moving average
- Prepare state for next function call
Key Features:
- **Batch-level consistency**: All samples in a batch share the same accept/reject fate
- **Adaptive control**: Uses multiple feedback mechanisms (global, windowed, PID)
- **Error compensation**: Corrects for deviations from target rate over time
- **Stability mechanisms**: Includes smoothing, limits, and perturbation for robustness
- **Token-level alignment**: Ensures token acceptance rate matches batch acceptance rate
Note: This function maintains internal state across calls through static variables,
so it will converge to the target acceptance rate over multiple invocations.
"""
if fixed_acceptance_rate is None:
return output_token_ids
else:
# Apply accept/reject decisions for the entire batch based on fixed_acceptance_rate
batch_size = output_token_ids.shape[0]
max_spec_len = output_token_ids.shape[1] - 1 # Get max_spec_len
device = output_token_ids.device
assert fixed_acceptance_rate >= 0 and fixed_acceptance_rate <= 1
# Use error compensation method to track global acceptance rate
# These are static variables that persist between calls
if not hasattr(filter_with_acceptance_rate, "cumulative_error"):
filter_with_acceptance_rate.cumulative_error = 0.0
if not hasattr(filter_with_acceptance_rate, "total_batches"):
filter_with_acceptance_rate.total_batches = 0
if not hasattr(filter_with_acceptance_rate, "accepted_batches"):
filter_with_acceptance_rate.accepted_batches = 0
if not hasattr(filter_with_acceptance_rate, "window_size"):
filter_with_acceptance_rate.window_size = 1000 # Sliding window size
if not hasattr(filter_with_acceptance_rate, "acceptance_history"):
filter_with_acceptance_rate.acceptance_history = [] # Track recent accept/reject history
if not hasattr(filter_with_acceptance_rate, "precision_adjustment"):
filter_with_acceptance_rate.precision_adjustment = 0.0 # Precision adjustment factor
if not hasattr(filter_with_acceptance_rate, "recent_adjustments"):
filter_with_acceptance_rate.recent_adjustments = [] # Recent adjustment history
if not hasattr(filter_with_acceptance_rate, "target_rate"):
filter_with_acceptance_rate.target_rate = fixed_acceptance_rate # Record target acceptance rate
else:
# If target acceptance rate changes, reset adjustment state
if filter_with_acceptance_rate.target_rate != fixed_acceptance_rate:
filter_with_acceptance_rate.precision_adjustment = 0.0
filter_with_acceptance_rate.recent_adjustments = []
filter_with_acceptance_rate.target_rate = fixed_acceptance_rate
# Update batch count
filter_with_acceptance_rate.total_batches += 1
# Calculate current global acceptance rate
global_rate = (filter_with_acceptance_rate.accepted_batches /
filter_with_acceptance_rate.total_batches if
filter_with_acceptance_rate.total_batches > 0 else 0.0)
# Calculate sliding window acceptance rate (focusing on recent performance)
filter_with_acceptance_rate.acceptance_history.append(0) # Default to reject
if len(filter_with_acceptance_rate.acceptance_history) > filter_with_acceptance_rate.window_size:
filter_with_acceptance_rate.acceptance_history.pop(0) # Remove oldest record
window_rate = sum(filter_with_acceptance_rate.acceptance_history) / len(filter_with_acceptance_rate.acceptance_history)
# Enhance precision for small batches - use smoother weight function
batch_weight_factor = 1.0 - math.exp(-filter_with_acceptance_rate.total_batches / 30.0) # Exponential smooth transition
# Dynamically adjust error weights: rely more on global error for fewer batches,
# more on sliding window error as batch count increases
window_size = len(filter_with_acceptance_rate.acceptance_history)
window_significance = min(window_size / 100.0, 0.9) # Window significance depends on historical data volume
window_weight = window_significance * batch_weight_factor
global_weight = 1.0 - window_weight
# Consider both global error and window error
combined_error = (global_weight * (global_rate - fixed_acceptance_rate) +
window_weight * (window_rate - fixed_acceptance_rate))
# Update precision adjustment factor - use PID controller style adjustment
if filter_with_acceptance_rate.total_batches > 50:
# Only perform precision adjustment when there's enough data
current_error = global_rate - fixed_acceptance_rate
# Save recent adjustment history
filter_with_acceptance_rate.recent_adjustments.append(current_error)
if len(filter_with_acceptance_rate.recent_adjustments) > 20: # Keep recent 20 errors
filter_with_acceptance_rate.recent_adjustments.pop(0)
# PID controller parameters
kp = 0.05 # Proportional coefficient
ki = 0.001 # Integral coefficient
kd = 0.01 # Derivative coefficient
# Proportional term - current error
p_term = current_error
# Integral term - accumulated error
i_term = sum(filter_with_acceptance_rate.recent_adjustments)
# Derivative term - error change rate
d_term = 0.0
if len(filter_with_acceptance_rate.recent_adjustments) >= 2:
d_term = filter_with_acceptance_rate.recent_adjustments[-1] - filter_with_acceptance_rate.recent_adjustments[-2]
# Calculate PID adjustment
pid_adjustment = kp * p_term + ki * i_term + kd * d_term
# Update precision adjustment factor
filter_with_acceptance_rate.precision_adjustment = pid_adjustment
# Limit adjustment factor range to prevent over-adjustment
max_adjustment = 0.02 + 0.03 * (1.0 - math.exp(-filter_with_acceptance_rate.total_batches / 500.0))
filter_with_acceptance_rate.precision_adjustment = max(-max_adjustment, min(max_adjustment, filter_with_acceptance_rate.precision_adjustment))
# Calculate acceptance rate correction factor
error_magnitude = abs(combined_error)
correction_factor = 1.0
# Use more refined error correction logic - use smooth nonlinear correction function
if error_magnitude > 0.0005: # Correct even smaller errors
# Use smooth correction function instead of piecewise function
base_strength = 2.0
error_scale = 1.0 - math.exp(-error_magnitude * 50.0) # Exponential decay mapping to [0,1]
correction_strength = base_strength + error_scale * 1.5 # Range from 2.0 to 3.5
# Smooth correction
sign = 1 if combined_error > 0 else -1
correction_factor = 1.0 + (correction_strength * error_magnitude * sign)
# Handle boundary cases to avoid division by zero
if correction_factor == 0.0:
correction_factor = 1.0
# Apply correction factor
adjusted_rate = max(0.0, min(1.0, fixed_acceptance_rate * (1.0 / correction_factor)))
# Apply precision adjustment factor
adjusted_rate = max(0.0, min(1.0, adjusted_rate - filter_with_acceptance_rate.precision_adjustment))
# More precise boundary case handling
if fixed_acceptance_rate > 0 and fixed_acceptance_rate < 0.05:
if filter_with_acceptance_rate.total_batches % int(1/fixed_acceptance_rate) == 0:
adjusted_rate = 1.0 # Periodically force accept to ensure accuracy in low acceptance rate scenarios
# If fixed_acceptance_rate is 0, directly reject
elif fixed_acceptance_rate == 0.0:
adjusted_rate = 0.0
# If fixed_acceptance_rate is 1, directly accept
elif fixed_acceptance_rate == 1.0:
adjusted_rate = 1.0
# Make precise adjustments for cases with large remaining errors
target_accepted = int(filter_with_acceptance_rate.total_batches * fixed_acceptance_rate + 0.5) # Round to nearest
actual_accepted = filter_with_acceptance_rate.accepted_batches
acceptance_gap = target_accepted - actual_accepted
# More aggressive gap adjustment strategy - use adaptive threshold and smooth adjustment
gap_relative = abs(acceptance_gap) / max(1, filter_with_acceptance_rate.total_batches)
gap_threshold = max(1, int(filter_with_acceptance_rate.total_batches * 0.005)) # Smaller dynamic threshold, at least 1
# Dynamically adjust acceptance rate based on the gap
if abs(acceptance_gap) >= gap_threshold: # Use dynamic threshold
# Use smooth adjustment strategy
if acceptance_gap > 0: # Need to accept more
# Use exponential function for smooth adjustment
gap_importance = 1.0 - math.exp(-gap_relative * 50.0) # Map to [0,1]
# Adjustment strength decreases as total batch count increases
strength_factor = math.exp(-filter_with_acceptance_rate.total_batches / 1000.0)
boost_factor = gap_importance * (0.2 + 0.8 * strength_factor) # Range from 0 to 1, decreases with total batch count
adjusted_rate = min(1.0, adjusted_rate + (1.0 - adjusted_rate) * boost_factor)
else: # Accepted too many, need to reject
# Use exponential function for smooth adjustment
gap_importance = 1.0 - math.exp(-gap_relative * 50.0) # Map to [0,1]
# Adjustment strength decreases as total batch count increases
strength_factor = math.exp(-filter_with_acceptance_rate.total_batches / 1000.0)
reduction_factor = gap_importance * (0.2 + 0.8 * strength_factor) # Range from 0 to 1, decreases with total batch count
adjusted_rate = max(0.0, adjusted_rate * (1.0 - reduction_factor))
# Add small random perturbation in fixed intervals to enhance convergence
if 0.01 < adjusted_rate < 0.99 and filter_with_acceptance_rate.total_batches > 100:
# Random perturbation amplitude decreases as batch count increases
noise_amplitude = 0.01 * math.exp(-filter_with_acceptance_rate.total_batches / 500.0)
noise = (torch.rand(1, device=device).item() * 2 - 1) * noise_amplitude
adjusted_rate = max(0.0, min(1.0, adjusted_rate + noise))
# Generate a random number to decide whether to accept the current batch
random_value = torch.rand(1, device=device).item()
accept_batch = random_value < adjusted_rate
# Set some tokens to PLACEHOLDER_TOKEN_ID to achieve specified acceptance rate
# Support max_spec_len > 1 cases
if accept_batch:
# Accept batch - don't modify token_ids
filter_with_acceptance_rate.accepted_batches += 1
filter_with_acceptance_rate.acceptance_history[-1] = 1 # Update the most recent acceptance status
else:
# Reject batch - set all speculative tokens (except first column) to PLACEHOLDER_TOKEN_ID
# This ensures token-level acceptance rate matches batch-level acceptance rate
output_token_ids[:, 1:] = PLACEHOLDER_TOKEN_ID
# Note: acceptance rate calculation is still based on entire batch accept/reject, no modification needed
# But we can add a comment explaining how actual token-level acceptance rate is calculated
# Actual token-level acceptance rate = 1 - (number of PLACEHOLDER_TOKEN_ID in output_token_ids / max_spec_len)
# Update cumulative error - use exponential moving average for smoother error adjustment
actual_rate = filter_with_acceptance_rate.accepted_batches / filter_with_acceptance_rate.total_batches
# Use EMA to smooth error updates - use adaptive EMA coefficient
alpha = 0.05 * math.exp(-filter_with_acceptance_rate.total_batches / 200.0) + 0.01 # EMA coefficient gradually decreases over time
filter_with_acceptance_rate.cumulative_error = (alpha * (actual_rate - fixed_acceptance_rate) +
(1 - alpha) * filter_with_acceptance_rate.cumulative_error)
return output_token_ids
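# Minimal smoke test for the controller above (illustrative helper, not
# called anywhere in this file): over many invocations the realized
# batch-level acceptance rate should converge to the requested target.
def _acceptance_rate_smoke_test(target: float = 0.3,
                                trials: int = 2000) -> float:
    out = torch.full((4, 3), 7, dtype=torch.int32)
    accepted = 0
    for _ in range(trials):
        res = filter_with_acceptance_rate(out.clone(), target)
        # Rejected batches have columns 1: set to PLACEHOLDER_TOKEN_ID (-1).
        if bool((res[:, 1:] != -1).all()):
            accepted += 1
    return accepted / trials  # expected to end up near `target`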
"""
=============================
End of MLU Hijack
=============================
"""
def vllm__v1__sample__rejection_sampler__rejection_sample(
# [num_tokens]
draft_token_ids: torch.Tensor,
# [batch_size]
num_draft_tokens: list[int],
max_spec_len: int,
# [batch_size]
cu_num_draft_tokens: torch.Tensor,
# [num_tokens, vocab_size]
draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1
assert target_probs.ndim == 2
batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1]
device = target_probs.device
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
'''
=============================
Modify by vllm_mlu
=============================
@brief: use tmo rejection_sample for all random sampling requests
'''
fixed_acceptance_rate = VLLM_MTP_FIXED_ACCEPTANCE_RATE
use_fusion_kernel = (sampling_metadata.all_random
and max_spec_len == 1
and (num_draft_tokens is not None
and 0 not in num_draft_tokens))
if use_fusion_kernel:
# All data is random, use tmo rejection_sample
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_rand = vllm__v1__sample__rejection_sampler__generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# generate random probs for recovered tokens
uniform_probs = generate_recovered_uniform_probs(
num_tokens,
vocab_size,
num_draft_tokens,
sampling_metadata,
device,
)
# num_draft_tokens needs to be a tensor
num_draft_tokens_tensor = torch.tensor(num_draft_tokens, dtype=torch.int32, device=device)
# tmo rejection_sample requires int32 token dtypes
bonus_token_ids = bonus_token_ids.to(torch.int32)
draft_token_ids = draft_token_ids.to(torch.int32)
# use tmo rejection_sample
output_token_ids = mlu_ops.rejection_sample(
draft_token_ids,
num_draft_tokens_tensor,
cu_num_draft_tokens,
draft_probs,
target_probs,
bonus_token_ids,
uniform_rand,
uniform_probs,
max_spec_len,
high_acc=True  # for now, only high_acc is supported
).view(batch_size, max_spec_len + 1)
if fixed_acceptance_rate is not None:
    # Zero out the speculative tokens here; filter_with_acceptance_rate
    # then marks rejected batches with PLACEHOLDER_TOKEN_ID.
    output_token_ids[:, 1:] = 0
output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate)
return output_token_ids
'''
=============================
End of MLU Hijack
=============================
'''
# Create output buffer.
output_token_ids = torch.full(
(batch_size, max_spec_len + 1),
PLACEHOLDER_TOKEN_ID,
dtype=torch.int32, # Consistent with SamplerOutput.sampled_token_ids.
device=device,
)
if sampling_metadata.all_greedy:
is_greedy = None
else:
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
has_acceptance_rate=fixed_acceptance_rate is not None,
)
if sampling_metadata.all_greedy:
output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate)
return output_token_ids
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = vllm__v1__sample__rejection_sampler__generate_uniform_probs(
num_tokens,
num_draft_tokens,
sampling_metadata.generators,
device,
)
# Sample recovered tokens for each position.
# [num_tokens]
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
sampling_metadata,
device,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add fixed acceptance rate check
'''
# Rejection sampling for random sampling requests.
vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel[(batch_size, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
NO_DRAFT_PROBS=draft_probs is None,
has_acceptance_rate=fixed_acceptance_rate is not None,
)
output_token_ids = filter_with_acceptance_rate(output_token_ids, fixed_acceptance_rate)
'''
==================
End of MLU Hijack
==================
'''
return output_token_ids
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
NO_DRAFT_PROBS: tl.constexpr,
has_acceptance_rate: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exit for greedy sampling requests.
return
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add accept rate check, always accept if has_acceptance_rate is True
'''
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if (draft_prob > 0 and target_prob / draft_prob >= uniform_prob) or has_acceptance_rate:
# Accept.
token_id = draft_token_id
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
'''
=============================
End of MLU Hijack
=============================
'''
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Check whether to accept bonus token through acceptance_rate_ptr
'''
# If has acceptance rate, all tokens are accepted
if has_acceptance_rate:
rejected = False
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)
'''
==================
End of MLU Hijack
==================
'''
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@triton.jit(do_not_specialize=["max_spec_len"])
def vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
max_spec_len,
has_acceptance_rate: tl.constexpr,
):
req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
# re-compilation may happen during runtime when is_greedy_ptr is None.
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
if not is_greedy:
# Early exit for non-greedy sampling requests.
return
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
if has_acceptance_rate:
rejected = False
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)
def vllm__v1__sample__rejection_sampler__generate_uniform_probs(
num_tokens: int,
num_draft_tokens: list[int],
generators: dict[int, torch.Generator],
device: torch.device,
) -> torch.Tensor:
"""
Generates a batch of uniform random samples, with optional seeding
if available.
This method creates a tensor of shape `(num_tokens, )` filled
with uniform random values in the range [0, 1). If `generators` is provided,
the requests with their own seeds will use the provided `torch.Generator`
for reproducibility. The samples for the other requests will be generated
without a seed.
Args:
num_tokens: int
Total number of tokens.
num_draft_tokens: list[int]
    Number of draft tokens per request.
generators: dict[int, torch.Generator]
    A dictionary mapping indices in the batch to
    `torch.Generator` objects.
device: torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand: torch.Tensor
A tensor of shape `(num_tokens, )` containing uniform
random values in the range [0, 1).
"""
# NOTE(woosuk): We deliberately use float64 instead of float32 here
# because when using float32, there's a non-negligible chance that
# uniform_prob is sampled to be exact 0.0 as reported in
# https://github.com/pytorch/pytorch/issues/16706. Using float64
# mitigates the issue.
'''
=============================
Modify by vllm_mlu
=============================
@brief: Changed torch.float64 to torch.float32
'''
uniform_probs = torch.rand(
(num_tokens,),
dtype=torch.float32,
device=device,
)
'''
==================
End of MLU Hijack
==================
'''
start_idx = 0
for req_idx, n in enumerate(num_draft_tokens):
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if n == 0:
continue
end_idx = start_idx + n
generator = generators.get(req_idx)
if generator is not None:
uniform_probs[start_idx:end_idx].uniform_(generator=generator)
start_idx = end_idx
return uniform_probs
MluHijackObject.apply_hijack(rejection_sampler,
rejection_sampler.generate_uniform_probs,
vllm__v1__sample__rejection_sampler__generate_uniform_probs)
MluHijackObject.apply_hijack(rejection_sampler,
rejection_sampler.expand_batch_to_tokens,
vllm__v1__sample__rejection_sampler__expand_batch_to_tokens)
MluHijackObject.apply_hijack(rejection_sampler,
rejection_sampler.expand_kernel,
vllm__v1__sample__rejection_sampler__expand_kernel)
MluHijackObject.apply_hijack(rejection_sampler,
rejection_sampler.sample_recovered_tokens_kernel,
vllm__v1__sample__rejection_sampler__sample_recovered_tokens_kernel)
MluHijackObject.apply_hijack(rejection_sampler,
rejection_sampler.rejection_sample,
vllm__v1__sample__rejection_sampler__rejection_sample)
MluHijackObject.apply_hijack(rejection_sampler,
rejection_sampler.rejection_random_sample_kernel,
vllm__v1__sample__rejection_sampler__rejection_random_sample_kernel)
MluHijackObject.apply_hijack(rejection_sampler,
rejection_sampler.rejection_greedy_sample_kernel,
vllm__v1__sample__rejection_sampler__rejection_greedy_sample_kernel)

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from vllm.config.model import LogprobsMode
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS
from vllm_mlu._mlu_utils import *
from vllm_mlu import _mlu_ops as mlu_ops
"""
@brief: use tmo random_sample
"""
def mlu_random_sample(
probs: torch.Tensor,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
is_gumbel_max = True
return mlu_ops.random_sample(probs, is_gumbel_max, generators).view(-1)
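# Reference-only sketch of Gumbel-max sampling, which mlu_ops.random_sample
# is assumed to implement when is_gumbel_max is True: adding i.i.d. Gumbel
# noise to log-probabilities and taking the argmax draws exactly from
# `probs` without inverting the CDF.
def _gumbel_max_sample_ref(probs: torch.Tensor) -> torch.Tensor:
    gumbel = -torch.log(-torch.log(torch.rand_like(probs)))
    return torch.argmax(torch.log(probs) + gumbel, dim=-1)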
class MluSampler(Sampler):
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
logprobs_mode_override: LogprobsMode | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Sample logits based on sampling metadata.
The various logits processing functions called in this method
may update the logits tensor in-place.
"""
logprobs_mode = logprobs_mode_override or self.logprobs_mode
assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
if sampling_metadata.all_random:
greedy_sampled = None
else:
greedy_sampled = self.greedy_sample(logits)
if sampling_metadata.all_greedy:
processed_logprobs = None
if sampling_metadata.max_num_logprobs is not None:
if logprobs_mode == "processed_logits":
processed_logprobs = logits
elif logprobs_mode == "processed_logprobs":
processed_logprobs = self.compute_logprobs(logits)
return greedy_sampled, processed_logprobs
assert sampling_metadata.temperature is not None
"""
=============================
Modify by vllm_mlu
=============================
@brief: use tmo topk_topp_sampler to sample.
"""
use_tmo = (sampling_metadata.top_k is not None) or (sampling_metadata.top_p is not None)
if use_tmo:
batch_size, vocab_size = logits.shape
index_in = torch.arange(vocab_size, dtype=torch.int32, device=logits.device)
(
logits_out,
sorted_logits_out,
index_out,
true_select_len,
) = mlu_ops.apply_topkp_v2(
logits,
index_in,
sampling_metadata.temperature,
None,
sampling_metadata.top_k.to(torch.int32) if sampling_metadata.top_k is not None else None,
sampling_metadata.top_p,
)
processed_logprobs = None
if logprobs_mode == "processed_logits":
processed_logprobs = logits
elif logprobs_mode == "processed_logprobs":
processed_logprobs = self.compute_logprobs(logits)
probs = logits_out.softmax(dim=-1, dtype=torch.float32)
random_sampled = mlu_random_sample(probs, sampling_metadata.generators)
else:
# Apply temperature.
logits = self.apply_temperature(
logits, sampling_metadata.temperature, sampling_metadata.all_random
)
# Apply logits processors that only apply to random sampling
# (argmax invariant)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
# Apply top_k and/or top_p.
random_sampled, processed_logprobs = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
sampling_metadata.top_k,
sampling_metadata.top_p,
)
"""
=================
End of MLU Hijack
=================
"""
if greedy_sampled is None:
return random_sampled, processed_logprobs
sampled = torch.where(
sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled,
random_sampled,
out=greedy_sampled, # Reuse tensor
)
return sampled, processed_logprobs

View File

@@ -0,0 +1,530 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import List, Optional, Any
import copy
import torch
import torch.nn.functional as F
from vllm.config.vllm import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, logger
from vllm.distributed.communication_op import tensor_model_parallel_all_gather_into_list
from vllm.distributed import (
get_logits_tp_world_size,
get_logits_tp_group,
get_tensor_model_parallel_world_size,
)
from vllm_mlu.v1.attention.backends.flash_attn import pad_attn_metadata
from vllm_mlu.v1.attention.backends.mla.flashmla import FlashMLAMetadataBuilder
from vllm_mlu.v1.attention.backends.utils import (
MLUCommonAttentionMetadata, COMMON_METADATA_STR)
from vllm_mlu._mlu_utils import *
from vllm_mlu.v1.attention.backends.utils import MLUInferMode
from vllm_mlu.mlu_forward_context import MLUDPMetadata
from vllm_mlu.v1.spec_decode.eagle import MluEagleProposer
from vllm_mlu.model_executor.models.dp_utils import (
enable_data_parallel,
DataParallelRuntimeParams
)
class DPMluEagleProposer(MluEagleProposer):
def get_logits_batch_sizes(self, batch_size: int) -> Optional[List[int]]:
tp_world_size, logits_batch_sizes = get_logits_tp_world_size(), None
if tp_world_size != get_tensor_model_parallel_world_size():
tp_tensor = torch.tensor([batch_size]).to(self.runner.device)
outputs = tensor_model_parallel_all_gather_into_list(tp_tensor, get_logits_tp_group())
# Convert device tensor to host list
outputs = torch.cat(outputs).tolist()
logits_batch_sizes = [outputs[i] for i in range(tp_world_size)]
return logits_batch_sizes
def propose_ds_execute_dummy_batch(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
dp_params: DataParallelRuntimeParams,
) -> None:
# num_scheduled_tokens
num_tokens = target_token_ids.shape[0]
input_ids = self.input_ids[:num_tokens]
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids[:-1] = target_token_ids[1:]
# always skip attn compute
attn_metadata: Optional[dict[str, Any]] = None
# Get graph-capture related information for the deepseek model.
with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_tokens):
hidden_states = self.model(
input_ids=input_ids,
positions=target_positions,
hidden_states=target_hidden_states,
intermediate_tensors=None,
inputs_embeds=None,
dp_params=dp_params,
)
if dp_params is not None:
dp_params.logits_batch_split_list = self.get_logits_batch_sizes(num_tokens)
_ = self.model.compute_logits(hidden_states, dp_params=dp_params)
if self.num_speculative_tokens == 1:
return
'''
=============================
Modify by vllm_mlu
@brief: support k > 1; the draft model must run k-1 additional times
=============================
'''
# support k > 1
for _ in range(self.num_speculative_tokens - 1):
new_dp_params = self.runner._get_data_parallel_metadata(
num_tokens, num_tokens, True, [1] * num_tokens)
with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_tokens):
hidden_states = self.model(
input_ids=input_ids,
positions=target_positions,
hidden_states=target_hidden_states,
intermediate_tensors=None,
inputs_embeds=None,
dp_params=new_dp_params,
)
_ = self.model.compute_logits(hidden_states, dp_params=new_dp_params)
'''
=============================
End of MLU Hijack
=============================
'''
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
last_token_indices: torch.Tensor | None,
common_attn_metadata: MLUCommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
# [batch_size]
num_rejected_tokens: torch.Tensor,
# [num_tokens]
token_indices: torch.Tensor,
whole_block_table: torch.Tensor,
main_model_dp_params: Optional[DataParallelRuntimeParams] = None,
time_markers: Optional[List] = None,
) -> torch.Tensor:
# Avoid a mutable default argument: a shared default list would
# accumulate timing events across calls.
if time_markers is None:
    time_markers = []
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
hidden_states_indices = last_token_indices
assert self.runner is not None
if self.attn_metadata_builder is None:
attn_metadata_builder = self._get_attention_metadata_builder()
else:
attn_metadata_builder = self.attn_metadata_builder
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use full graph with draft model and pad batch_size for dp
'''
dp_group_max_token_num = max(main_model_dp_params.token_split_list)
if dp_group_max_token_num <= self.vllm_config.compilation_config.max_cudagraph_capture_size:
batch_descriptor_num_tokens = self.vllm_config.pad_for_cudagraph(dp_group_max_token_num)
captured_already = True
else:
batch_descriptor_num_tokens = num_tokens
captured_already = False
# Determine if we can use full graph
decode_only = all(not prefill for prefill in main_model_dp_params.dp_is_prefill)
# FIXME(wangchao2): disable mtp graph for ds3.2 with dp for now (core dump)
is_dsv32 = self.vllm_config.model_config.hf_config.model_type == "deepseek_v32"
use_full_graph = (self.use_cuda_graph
and decode_only and captured_already and not is_dsv32)
if (self.use_cuda_graph and decode_only and not use_full_graph and not is_dsv32):
logger.warning_once(
    "MLU-V1 Full-MLUGraph mode was selected for the drafter, but it is "
    f"running in eager mode: decode_only={decode_only}, "
    f"captured_already={captured_already}, num_tokens={num_tokens}."
)
cudagraph_runtime_mode = CUDAGraphMode.FULL if use_full_graph else CUDAGraphMode.NONE
batch_descriptor = BatchDescriptor(
num_tokens=batch_descriptor_num_tokens,
uniform_decode=True,
)
# dp pad batch_size
if use_full_graph:
K = self.num_speculative_tokens
num_input_tokens = batch_descriptor_num_tokens
padded_batch_size = num_input_tokens // (K + 1)
else:
padded_batch_size = batch_size
num_input_tokens = num_tokens
# change attn metadata num_actual_tokens
attn_metadata.num_actual_tokens = num_input_tokens
common_attn_metadata_copy = None
# Copy common_attn_metadata when k > 1 for the draft model, because DP
# batch-size padding mutates common_attn_metadata in place.
if self.num_speculative_tokens > 1:
common_attn_metadata_copy = copy.deepcopy(common_attn_metadata)
# pad attn metadata
if use_full_graph and enable_data_parallel() and num_input_tokens != num_tokens:
assert self.runner is not None
# Update attention metadata.
pad_attn_metadata(
attn_metadata,
common_attn_metadata,
whole_block_table,
self.runner,
num_tokens,
num_input_tokens,
batch_size,
padded_batch_size,
)
# Compute the token padding size; the persistent input_ids buffer
# already covers the padded region, so no explicit pad is needed.
token_pad_size = num_input_tokens - num_tokens
assert token_pad_size >= 0
# Update target hidden states, pad with zeros if necessary.
if token_pad_size > 0:
target_hidden_states = F.pad(
target_hidden_states,
(0, 0, 0, token_pad_size),
value=0.0
)
# Update positions, pad with zeros if necessary.
if token_pad_size > 0:
target_positions = F.pad(
target_positions,
(0, token_pad_size),
value=0
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata
# copy inputs to buffer for cudagraph
self.positions[:num_input_tokens] = target_positions
self.hidden_states[:num_input_tokens] = target_hidden_states
kwargs = {} if main_model_dp_params is None else {"dp_params": main_model_dp_params}
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
start = torch.mlu.Event(enable_timing=True)
start.record()
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
batch_descriptor=batch_descriptor if use_full_graph else None,
cudagraph_runtime_mode=cudagraph_runtime_mode):
if use_full_graph:
ret_hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
intermediate_tensors=None,
inputs_embeds=None,
is_running_drafter=True,
**kwargs,
)
else:
ret_hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
**kwargs,
)
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
end = torch.mlu.Event(enable_timing=True)
end.record()
time_markers.append([start, end])
if self.method == "mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
'''
=============================
End of MLU Hijack
=============================
'''
if main_model_dp_params is not None:
# Ensure main_model_dp_params has required attribute before assignment
if hasattr(main_model_dp_params, 'logits_batch_split_list'):
main_model_dp_params.logits_batch_split_list = self.get_logits_batch_sizes(batch_size)
else:
raise AttributeError("dp_params must have 'logits_batch_split_list' attribute")
sample_hidden_states = last_hidden_states[hidden_states_indices]
logits = self.model.compute_logits(sample_hidden_states, dp_params=main_model_dp_params)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
if self.uses_mrope:
positions = target_positions[:, last_token_indices]
else:
positions = target_positions[last_token_indices]
'''
=============================
Modify by vllm_mlu
=============================
'''
hidden_states = last_hidden_states[hidden_states_indices]
'''
=============================
End of MLU Hijack
=============================
'''
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
input_batch_size = batch_size
if common_attn_metadata.infer_mode != MLUInferMode.DECODE_ONLY:
seq_lens_cpu = torch.ones(input_batch_size, dtype=torch.int32,)
cu_num_tokens = torch.cumsum(seq_lens_cpu, dim=0)
query_start_loc_cpu = torch.empty(input_batch_size + 1, dtype=torch.int32)
query_start_loc_cpu[0] = 0
query_start_loc_cpu[1:] = cu_num_tokens
seq_start_loc_cpu = self.arange[:input_batch_size + 1]
common_attn_metadata_k = MLUCommonAttentionMetadata.build(
query_start_loc=query_start_loc_cpu.to(self.device, non_blocking=True),
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens_cpu.to(self.device, non_blocking=True),
seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
seq_start_loc=seq_start_loc_cpu.to(self.device, non_blocking=True),
is_start_loc_match=False, # not prefill
max_query_len=1,
num_actual_tokens=input_batch_size,
num_input_tokens=input_batch_size,
num_speculative_tokens=self.num_speculative_tokens,
has_prefill_reqs=common_attn_metadata.infer_mode == MLUInferMode.CHUNKED,
)
else:
common_attn_metadata_k = common_attn_metadata_copy
common_attn_metadata_k.num_actual_tokens = batch_size
common_attn_metadata_k.num_input_tokens = batch_size
common_attn_metadata_k.max_query_len = 1
common_attn_metadata_k.query_start_loc = self.arange[: batch_size + 1]
common_attn_metadata_k.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[: batch_size + 1]
).clone()
# In padded drafter batch, we need to adjust the sequence lengths
# to remove the "padding" (i.e. rejected tokens).
# Only apply this adjustment when we have rejected tokens
# (i.e., not the first proposal).
for token_index in range(self.num_speculative_tokens - 1):
'''
=============================
Modify by vllm_mlu
=============================
@brief: get dp_params for draft model
'''
# dp_params for the draft model; None when the main model is not running
# with data parallelism (compute_logits below then receives dp_params=None).
dp_params = None
if main_model_dp_params is not None:
dp_params = self.runner._get_data_parallel_metadata(
input_batch_size,
input_batch_size,
common_attn_metadata.is_decode_only,
[1] * input_batch_size
)
kwargs = {} if main_model_dp_params is None else {"dp_params": dp_params}
'''
=============================
End of MLU Hijack
=============================
'''
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
if self.uses_mrope:
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length.
# Since it is complex to remove such requests from the batch,
# we keep them in the batch but adjust the position ids
# and slot mappings to avoid the
# out-of-range access during the model execution.
# The draft tokens generated with this adjustment
# should be ignored.
exceeds_max_model_len = positions[0] >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(
exceeds_max_model_len.unsqueeze(0),
torch.zeros_like(positions),
positions,
)
else:
positions += 1
exceeds_max_model_len = positions >= self.max_model_len
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# For data integrity with async scheduling, avoid in-place operations
# here, since the tensors may be modified by the main model's
# `prepare_input` in the next step.
# Increment the sequence lengths.
common_attn_metadata_k.seq_lens += 1
# This is an out-of-place operation to avoid modifying the original tensor.
common_attn_metadata_k.seq_lens_cpu = common_attn_metadata_k.seq_lens_cpu + 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata_k.seq_lens.masked_fill_(exceeds_max_model_len, 1)
common_attn_metadata_k.num_computed_tokens_cpu = (
common_attn_metadata_k.seq_lens_cpu - 1
)
# Compute the slot mapping.
if self.uses_mrope:
# all dimensions of positions are the same
block_numbers = clamped_positions[0] // self.block_size
else:
block_numbers = clamped_positions // self.block_size
block_ids = common_attn_metadata_k.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1)
)
block_ids = block_ids.view(-1)
if self.uses_mrope:
common_attn_metadata_k.slot_mapping = (
block_ids * self.block_size + clamped_positions[0] % self.block_size
)
else:
common_attn_metadata_k.slot_mapping = (
block_ids * self.block_size + clamped_positions % self.block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
common_attn_metadata_k.slot_mapping.masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID
)
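# Worked example (illustrative numbers): with block_size=16, a clamped
# position of 37 and a block-table row [8, 3, 12, ...], block_numbers is
# 37 // 16 = 2, block_ids picks row[2] = 12, and the slot becomes
# 12 * 16 + 37 % 16 = 197; the masked_fill_ above then redirects requests
# past max_model_len to PADDING_SLOT_ID.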
# Rebuild attention metadata
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata_k, draft_index=token_index + 1
)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
per_layer_attn_metadata[COMMON_METADATA_STR] = common_attn_metadata_k
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
if self.supports_mm_inputs:
self.inputs_embeds[:batch_size] = self.model.embed_input_ids(input_ids)
input_ids = None
inputs_embeds = self.inputs_embeds[:input_batch_size]
else:
input_ids = self.input_ids[:input_batch_size]
inputs_embeds = None
'''
=============================
Modify by vllm_mlu
=============================
@brief: record latency
'''
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
start = torch.mlu.Event(enable_timing=True)
start.record()
'''
=============================
End of MLU Hijack
=============================
'''
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
ret_hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
**kwargs,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
'''
=============================
Modify by vllm_mlu
=============================
@brief: record latency
'''
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
end = torch.mlu.Event(enable_timing=True)
end.record()
time_markers.append([start, end])
'''
=============================
End of MLU Hijack
=============================
'''
hidden_states = last_hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
dp_params=dp_params)
# TODO(wenlong): get more than one token for tree attention
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids

File diff suppressed because it is too large

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import numpy as np
import torch
from vllm.distributed import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.worker.block_table import BlockTable
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
class BlockTable_MluHijack(BlockTable):
def __init__(
self,
block_size: int,
max_num_reqs: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
dcp_kv_cache_interleave_size: int,
):
"""
Args:
block_size: Block size used for KV cache memory allocation
max_num_reqs: Maximum number of concurrent requests supported.
max_num_blocks_per_req: Maximum number of blocks per request.
max_num_batched_tokens: Maximum number of tokens in a batch.
pin_memory: Whether to pin memory for faster GPU transfers.
device: Target device for the block table.
kernel_block_size: The block_size of the underlying attention kernel.
Will be the same as `block_size` if `block_size` is supported
by the attention kernel.
dcp_kv_cache_interleave_size: Interleave size of KV cache entries
across decode context parallel (DCP) ranks.
"""
self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.pin_memory = pin_memory
self.device = device
if kernel_block_size == block_size:
# Standard case: allocation and computation use same block size
# No block splitting needed, direct mapping
self.block_size = block_size
self.blocks_per_kv_block = 1
self.use_hybrid_blocks = False
else:
# Hybrid case: allocation block size differs from kernel block size
# Memory blocks are subdivided to match kernel requirements
# Example: 32-token memory blocks with 16-token kernel blocks
# → Each memory block corresponds to 2 kernel blocks
if block_size % kernel_block_size != 0:
raise ValueError(
f"kernel_block_size {kernel_block_size} must divide "
f"kv_manager_block_size size {block_size} evenly"
)
self.block_size = kernel_block_size
self.blocks_per_kv_block = block_size // kernel_block_size
self.use_hybrid_blocks = True
self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block
self.block_table = self._make_buffer(
self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
)
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
'''
=============================
Modify by vllm_mlu
=============================
@brief: change slot_mapping dtype from int64 to int32
'''
self.slot_mapping = self._make_buffer(
self.max_num_batched_tokens, dtype=torch.int32
)
'''
==================
End of MLU Hijack
==================
'''
if self.use_hybrid_blocks:
self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
1, -1
)
else:
self._kernel_block_arange = None
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
except AssertionError:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
MluHijackObject.apply_hijack(
BlockTable,
BlockTable.__init__,
BlockTable_MluHijack.__init__
)
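# Illustrative sketch (not part of the diff): with use_hybrid_blocks, each
# KV-manager block of `block_size` tokens is split into
# blocks_per_kv_block = block_size // kernel_block_size consecutive kernel
# blocks, so memory block i owns kernel blocks [i * n, i * n + n). A minimal
# numpy version of that expansion (function name is hypothetical):
import numpy as np

def expand_to_kernel_blocks(kv_block_ids: np.ndarray,
                            blocks_per_kv_block: int) -> np.ndarray:
    # e.g. 32-token memory blocks over a 16-token kernel -> n = 2
    arange = np.arange(blocks_per_kv_block).reshape(1, -1)
    return (kv_block_ids.reshape(-1, 1) * blocks_per_kv_block + arange).ravel()

# expand_to_kernel_blocks(np.array([3, 7]), 2) -> array([ 6,  7, 14, 15])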

File diff suppressed because it is too large


@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def split_decodes_and_prefills(self):
decodes = 0
prefills = 0
for req_id in self.req_ids:
req_index = self.req_id_to_index[req_id]
num_prompt_tokens = self.num_prompt_tokens[req_index]
num_computed_tokens = self.num_computed_tokens_cpu[req_index]
if num_computed_tokens < num_prompt_tokens:
prefills += 1
else:
decodes += 1
return decodes, prefills
MluHijackObject.apply_hijack(InputBatch,
"split_decodes_and_prefills",
split_decodes_and_prefills)
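# Illustrative sketch (not part of the diff): the hijacked method counts a
# request as a prefill while it is still consuming prompt tokens
# (num_computed < num_prompt) and as a decode afterwards. The same rule on
# plain lists, with hypothetical toy data:
def count_decodes_and_prefills(num_prompt_tokens, num_computed_tokens):
    decodes = prefills = 0
    for prompt_len, computed in zip(num_prompt_tokens, num_computed_tokens):
        if computed < prompt_len:
            prefills += 1
        else:
            decodes += 1
    return decodes, prefills

# count_decodes_and_prefills([8, 8, 4], [8, 3, 10]) -> (2, 1)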

File diff suppressed because it is too large


@@ -0,0 +1,638 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""A GPU worker class."""
import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_tp_group, get_pp_group
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.utils.mem_constants import GiB_bytes
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm_mlu.model_executor.warmup.kernel_warmup import kernel_warmup
from vllm_mlu.profiler.mlu_profiler import MluProfilerWrapper
from vllm_mlu.utils import MemorySnapshot, memory_profiling
from vllm_mlu._mlu_utils import VLLM_DUMP_MLU_INFO_EN
from vllm_mlu.device_allocator.cnmem import CnMemAllocator
from vllm_mlu.v1.worker.mlu_quant import MLUWorkerQuant
from vllm_mlu.v1.worker.gpu_model_runner import MLUModelRunner
from vllm_mlu.v1.worker.dp_gpu_model_runner import DPMLUModelRunner
logger = init_logger(__name__)
class MLUWorker(Worker, MLUWorkerQuant):
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
):
WorkerBase.__init__(self, vllm_config=vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker)
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules
init_cached_hf_modules()
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
logger.info(
"Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir,
)
logger.debug(
"Profiler config: record_shapes=%s,"
"profile_memory=%s,with_stack=%s,with_flops=%s",
envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
envs.VLLM_TORCH_PROFILER_WITH_STACK,
envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.MLU,
],
record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
),
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
self.profiler = MluProfilerWrapper()
else:
self.profiler = None
def sleep(self, level: int = 1) -> None:
free_bytes_before_sleep = torch.mlu.mem_get_info()[0]
# Save the buffers before level 2 sleep
if level == 2:
model = self.model_runner.model
self._sleep_saved_buffers = {
name: buffer.cpu().clone() for name, buffer in model.named_buffers()
}
allocator = CnMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.mlu.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
logger.info(
"Sleep mode freed %.2f GiB memory, "
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CnMemAllocator.get_instance()
allocator.wake_up(tags)
# Restore the buffers after level 2 sleep
if len(self._sleep_saved_buffers):
model = self.model_runner.model
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)
self._sleep_saved_buffers = {}
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CnMemAllocator.get_instance()
if tag == "weights":
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be used for one instance per process."
)
context = allocator.use_memory_pool(tag=tag)
else:
context = nullcontext()
return context
def init_device(self):
if self.device_config.device.type == "mlu":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("CNCL_ASYNC_ERROR_HANDLING", None)
# if (
# self.parallel_config.data_parallel_size > 1
# and self.parallel_config.data_parallel_size_local > 0
# and self.parallel_config.distributed_executor_backend
# not in ["ray", "external_launcher"]
# and self.vllm_config.parallel_config.data_parallel_backend != "ray"
# ):
# # Use local DP rank if available, otherwise use global DP rank.
# dp_local_rank = self.parallel_config.data_parallel_rank_local
# if dp_local_rank is None:
# dp_local_rank = self.parallel_config.data_parallel_rank
# tp_pp_world_size = (
# self.parallel_config.pipeline_parallel_size
# * self.parallel_config.tensor_parallel_size
# )
# # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
# self.local_rank += dp_local_rank * tp_pp_world_size
# assert self.local_rank < torch.mlu.device_count(), (
# f"DP adjusted local rank {self.local_rank} is out of bounds. "
# )
self.device = torch.device(f"mlu:{self.local_rank}")
current_platform.set_device(self.device)
current_platform.check_if_supports_dtype(self.model_config.dtype)
# Initialize the distributed environment BEFORE taking
# memory snapshot
# This ensures NCCL buffers are allocated before we measure
# available memory
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
current_platform.dist_backend,
)
# Set random seed.
set_random_seed(self.model_config.seed)
gc.collect()
torch.mlu.empty_cache()
# take current memory snapshot
self.init_snapshot = MemorySnapshot()
self.requested_memory = (
self.init_snapshot.total_memory
* self.cache_config.gpu_memory_utilization
)
if self.init_snapshot.free_memory < self.requested_memory:
GiB = lambda b: round(b / GiB_bytes, 2)
raise ValueError(
f"Free memory on device "
f"({GiB(self.init_snapshot.free_memory)}/"
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
f"is less than desired GPU memory utilization "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
f"utilization or reduce GPU memory used by other processes."
)
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")
# Construct the model runner
model_runner_cls = (DPMLUModelRunner
if self._enable_moe_dp_opt() else MLUModelRunner)
self.model_runner: MLUModelRunner = model_runner_cls(
self.vllm_config, self.device)
if self.rank == 0:
# If usage stat is enabled, collect relevant info.
report_usage_stats(self.vllm_config)
@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
memory can be used for KV cache without OOMs.
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the free memory that can be used for KV cache in
bytes.
Tip:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
GiB = lambda b: b / GiB_bytes
if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
# still need a profile run which compiles the model for
# max_num_batched_tokens
self.model_runner.profile_run()
msg = (
f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} "
f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for "
"KV Cache as specified by kv_cache_memory_bytes config and "
"skipped memory profiling. This does not respect the "
"gpu_memory_utilization config. Only use kv_cache_memory_bytes "
"config when you want manual control of KV cache memory "
"size. If OOM'ed, check the difference of initial free "
"memory between the current run and the previous run "
"where kv_cache_memory_bytes is suggested and update it "
"correspondingly."
)
logger.info(msg)
return kv_cache_memory_bytes
torch.mlu.empty_cache()
torch.mlu.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(
self.init_snapshot,
weights_memory=int(self.model_runner.model_memory_usage),
) as profile_result:
self.model_runner.profile_run()
self.non_torch_memory = profile_result.non_torch_increase
self.peak_activation_memory = profile_result.torch_peak_increase
free_gpu_memory = profile_result.after_profile.free_memory
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
assert self.init_snapshot.free_memory > free_gpu_memory, (
"Error in memory profiling. "
f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
f"current free memory {GiB(free_gpu_memory)} GiB. "
"This happens when other processes sharing the same container "
"release GPU memory while vLLM is profiling during initialization. "
"To fix this, ensure consistent GPU memory allocation or "
"isolate vLLM in its own container."
)
self.available_kv_cache_memory_bytes = (
self.requested_memory - profile_result.non_kv_cache_memory
)
unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
logger.debug(
"Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB",
GiB(self.init_snapshot.free_memory),
self.cache_config.gpu_memory_utilization,
GiB(self.requested_memory),
)
logger.debug(
"Free memory after profiling: %.2f GiB (total), "
"%.2f GiB (within requested)",
GiB(free_gpu_memory),
GiB(free_gpu_memory - unrequested_memory),
)
logger.debug(profile_result)
logger.info_once(
"Available KV cache memory: %.2f GiB",
GiB(self.available_kv_cache_memory_bytes),
scope="local",
)
gc.collect()
self.peak_memory = profile_result.non_kv_cache_memory
self.block_memory = self.available_kv_cache_memory_bytes
return int(self.available_kv_cache_memory_bytes)
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
# Init kv cache connector here, because it requires
# `kv_cache_config`.
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CnMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
else:
context = nullcontext()
with context:
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
x for x in warmup_sizes
if x not in self.vllm_config.compilation_config.cudagraph_capture_sizes
]
# We skip EPLB here since we don't want to record dummy metrics
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size, skip_eplb=True, remove_lora=False)
self.model_runner.maybe_remove_all_loras(self.model_runner.lora_config)
# Warmup and tune the kernels used during model execution before
# cuda graph capture.
kernel_warmup(self)
cuda_graph_memory_bytes = 0
if not self.model_config.enforce_eager:
cuda_graph_memory_bytes = self.model_runner.capture_model()
if self.cache_config.kv_cache_memory_bytes is None and hasattr(
self, "peak_activation_memory"
):
# Suggests optimal kv cache memory size if we rely on
# memory_profiling to guess the kv cache memory size which
# provides peak_activation_memory and a few other memory
# consumption. `memory_profiling` does not consider
# CUDAGraph memory size and may not utilize all gpu memory.
# Users may want fine-grained control to specify kv cache
# memory size.
GiB = lambda b: round(b / GiB_bytes, 2)
# empirically observed that the memory profiling may
# slightly underestimate the memory consumption.
# So leave a small buffer (=150MiB) to avoid OOM.
redundancy_buffer_memory = 150 * (1 << 20)
non_kv_cache_memory = (
self.model_runner.model_memory_usage
+ self.peak_activation_memory
+ self.non_torch_memory
+ cuda_graph_memory_bytes
)
kv_cache_memory_bytes_to_gpu_limit = (
self.init_snapshot.free_memory
- non_kv_cache_memory
- redundancy_buffer_memory
)
kv_cache_memory_bytes_to_requested_limit = (
int(self.requested_memory)
- non_kv_cache_memory
- redundancy_buffer_memory
)
msg = (
f"Free memory on device "
f"({GiB(self.init_snapshot.free_memory)}/"
f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
f"Desired GPU memory utilization is "
f"({self.cache_config.gpu_memory_utilization}, "
f"{GiB(self.requested_memory)} GiB). "
f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
f"for peak activation, {GiB(self.non_torch_memory)} GiB "
f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
f"config with `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_requested_limit}` "
f"({GiB(kv_cache_memory_bytes_to_requested_limit)} GiB) to fit "
f"into requested memory, or `--kv-cache-memory="
f"{kv_cache_memory_bytes_to_gpu_limit}` "
f"({GiB(kv_cache_memory_bytes_to_gpu_limit)} GiB) to fully "
f"utilize gpu memory. Current kv cache memory in use is "
f"{GiB(self.available_kv_cache_memory_bytes)} GiB."
)
logger.debug(msg)
# Warm up sampler and preallocate memory buffer for logits and other
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
if get_pp_group().is_last_rank:
max_num_reqs = min(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
)
# We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = self.model_runner._dummy_run(
num_tokens=max_num_reqs,
skip_eplb=True,
)
if self.model_runner.is_pooling_model:
self.model_runner._dummy_pooler_run(hidden_states)
else:
self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states)
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
@torch.inference_mode()
def execute_model(
self, scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput | None:
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
self.vllm_config, num_input_tokens
)
}
if forward_pass and not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
)
)
with self.annotate_profile(scheduler_output):
output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if isinstance(output, (ModelRunnerOutput, NoneType)):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
assert (
parallel_config.distributed_executor_backend != "external_launcher"
and not get_pp_group().is_last_rank
)
get_pp_group().send_tensor_dict(
output.tensors,
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
)
return None
def _enable_moe_dp_opt(self):
'''
Enable the MLU-optimized data-parallel scheme for supported MoE models;
otherwise fall back to the native DP implementation.
'''
# case 0: data parallel enabled
enable_dp = self.parallel_config.data_parallel_size > 1
# case 1: DeepSeek MLA
is_ds_mla = self.model_config.is_deepseek_mla
# case 2: supported MoE models (Qwen3-MoE / GLM4-MoE)
is_supported_moe_model = hasattr(self.model_config.hf_text_config, "model_type") and \
self.model_config.hf_text_config.model_type in ('qwen3_moe', 'glm4_moe')
# case 3: private model
is_private_model = getattr(self.model_config.hf_config, "is_private", False)
return enable_dp and (is_ds_mla or is_supported_moe_model or is_private_model)
def execute_dummy_batch(self) -> None:
if self._enable_moe_dp_opt():
self.model_runner.moe_dp_execute_dummy_batch(1)
else:
self.model_runner._dummy_run(1, uniform_decode=True)
def response_remote_alloc_once(self) -> None:
self.model_runner.response_remote_alloc_once()
def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None:
from vllm.distributed.parallel_state import get_ep_group
if get_ep_group().rank == 0:
logger.info(
"[Elastic EP] Starting expert resharding before scaling down..."
)
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
assert self.model_runner.eplb_state is not None
self.model_runner.eplb_state.rearrange(
execute_shuffle=True,
global_expert_load=None,
rank_mapping=rank_mapping,
)
torch.mlu.synchronize()
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed!")
def reinitialize_distributed(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
from vllm.config import set_current_vllm_config
from vllm.distributed.parallel_state import (
cleanup_dist_env_and_memory,
get_ep_group,
)
old_ep_size = get_ep_group().world_size
old_ep_rank = get_ep_group().rank
new_ep_size = (
reconfig_request.new_data_parallel_size
* get_tp_group().world_size
* get_pp_group().world_size
)
if new_ep_size < old_ep_size:
self._eplb_before_scale_down(old_ep_size, new_ep_size)
cleanup_dist_env_and_memory()
if (
reconfig_request.new_data_parallel_rank
== ReconfigureRankType.SHUTDOWN_CURRENT_RANK
):
assert old_ep_rank >= new_ep_size
# shutdown
return
self._reconfigure_parallel_config(reconfig_request)
with set_current_vllm_config(self.vllm_config):
init_worker_distributed_environment(
self.vllm_config,
self.rank,
self.distributed_init_method,
self.local_rank,
current_platform.dist_backend,
)
global_expert_loads = self._reconfigure_moe(old_ep_size, new_ep_size)
if new_ep_size > old_ep_size:
assert global_expert_loads is not None
self._eplb_after_scale_up(old_ep_size, new_ep_size, global_expert_loads)
def get_hfu_info(self, batch, input_len, output_len):
try:
self.model_runner.model.collect_hfu_io_effciency_info(batch, input_len, output_len)
if VLLM_DUMP_MLU_INFO_EN:
return self.model_runner.model.hfu_info, self.model_runner.model.io_efficiency
else:
return self.model_runner.model.flops_info, 0.0
except Exception as e:
raise RuntimeError(
"Model match failure when getting HFU info, please check "
"if an init method was registered."
) from e
def _get_latency(self, time_markers):
total_latency = 0
if not isinstance(time_markers, list):
time_markers = [time_markers]
for time_marker in time_markers:
start, end = time_marker
latency = start.elapsed_time(end)
total_latency += latency
return total_latency
def get_latency(self):
return self._get_latency(self.model_runner.time_markers)
def get_mm_encoder_latency(self):
if not hasattr(self.model_runner, "mm_time_markers"):
return None
mm_time_markers = self.model_runner.mm_time_markers
return (None if len(mm_time_markers) == 0
else self._get_latency(mm_time_markers))
def get_memory_usage(self):
return (self.peak_memory, self.block_memory)
def recapture_model(self,
prefill_enable_mlugraph: bool,
batch_size: int,
input_len: int):
# Reset history capture context
self.model_runner.reset_capture_context(
prefill_enable_mlugraph, batch_size, input_len)
# Re-capture the decode graph (full graph or piecewise graph)
self.compile_or_warm_up_model()
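# Illustrative sketch (not part of the diff): the --kv-cache-memory suggestion
# logged in compile_or_warm_up_model() is plain arithmetic over the profiled
# components plus the 150 MiB safety buffer. A minimal version with
# hypothetical byte counts (function and parameter names are stand-ins):
GIB = 1 << 30

def suggest_kv_cache_bytes(memory_limit: int, weights: int,
                           peak_activation: int, non_torch: int,
                           graph_capture: int,
                           buffer: int = 150 * (1 << 20)) -> int:
    # non-KV-cache memory = weights + peak activation + non-torch + graphs
    non_kv_cache_memory = weights + peak_activation + non_torch + graph_capture
    return memory_limit - non_kv_cache_memory - buffer

# e.g. 80 GiB requested, 60 GiB weights, 4 GiB peak activation, 1 GiB
# non-torch, 2 GiB graphs -> roughly 12.85 GiB for --kv-cache-memory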


@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""
Define KV connector functionality mixin for model runners.
"""
import copy
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, nullcontext
from typing import (
TYPE_CHECKING, # noqa: UP035
)
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_shutdown,
get_kv_transfer_group,
has_kv_transfer_group,
)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
KVConnectorOutput,
ModelRunnerOutput,
)
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
# Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU)
class KVConnectorModelRunnerMixin_MluHijack(KVConnectorModelRunnerMixin):
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
'''
=============================
Modify by vllm_mlu
=============================
@brief: support disaggregated serving for MLU.
'''
kv_connector.request_remote_memory_send()
'''
==================
End of MLU Hijack
==================
'''
# This context manager must be used within an active forward context.
# It encapsulates the entire KV connector lifecycle within execute_model
@staticmethod
@contextmanager
def _get_kv_connector_output(
scheduler_output: "SchedulerOutput", wait_for_save: bool = True
) -> Generator[KVConnectorOutput, None, None]:
output = KVConnectorOutput()
# Update KVConnector with the KVConnector metadata forward().
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())
'''
=============================
Modify by vllm_mlu
=============================
@brief: support disaggregated serving for MLU.
'''
kv_connector.request_remote_memory_send()
'''
==================
End of MLU Hijack
==================
'''
try:
yield output
finally:
output.finished_sending, output.finished_recving = (
kv_connector.get_finished(scheduler_output.finished_req_ids)
)
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = (
KVConnectorModelRunnerMixin.get_kv_connector_stats()
)
MluHijackObject.apply_hijack(KVConnectorModelRunnerMixin,
KVConnectorModelRunnerMixin.maybe_setup_kv_connector,
KVConnectorModelRunnerMixin_MluHijack.maybe_setup_kv_connector)
MluHijackObject.apply_hijack(KVConnectorModelRunnerMixin,
KVConnectorModelRunnerMixin._get_kv_connector_output,
KVConnectorModelRunnerMixin_MluHijack._get_kv_connector_output)
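# Illustrative sketch (not part of the diff): the hijack keeps the upstream
# lifecycle: transfers start on entry, the forward pass runs inside the
# `with`, and results are collected on exit even if the model raises. A
# self-contained sketch of that try/finally shape (the connector object and
# output fields here are stand-ins, not vLLM APIs):
from contextlib import contextmanager
from dataclasses import dataclass, field

@dataclass
class _SketchOutput:
    finished_sending: set = field(default_factory=set)
    finished_recving: set = field(default_factory=set)

@contextmanager
def _sketch_connector_output(connector):
    connector.start_load_kv()            # async KV loads kick off on entry
    output = _SketchOutput()
    try:
        yield output                     # model forward runs here
    finally:                             # populate results even on error
        output.finished_sending, output.finished_recving = (
            connector.get_finished())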


@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import List
from vllm.lora.request import LoRARequest
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def vllm_mlu__v1__worker__LoRAModelRunnerMixin__add_dummy_loras(self, num_loras: int) -> List[LoRARequest]:
assert num_loras > 0
assert self.lora_manager is not None
dummy_lora_requests: list[LoRARequest] = []
with self.lora_manager.dummy_lora_cache():
for idx in range(num_loras):
lora_id = idx + 1
dummy_lora_request = LoRARequest(
lora_name=f"capture_graph_{lora_id}",
lora_int_id=lora_id,
lora_path="/not/a/real/path",
)
self.lora_manager.add_dummy_lora(dummy_lora_request,
rank=self.LORA_WARMUP_RANK)
dummy_lora_requests.append(dummy_lora_request)
return dummy_lora_requests
MluHijackObject.apply_hijack(LoRAModelRunnerMixin,
"add_dummy_loras",
vllm_mlu__v1__worker__LoRAModelRunnerMixin__add_dummy_loras)


@@ -0,0 +1,281 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
"""A MLU quant class."""
import functools
from collections import defaultdict
from typing import Dict, Any, List, Optional, Union
import numpy as np
import torch
import torch.distributed
from vllm.distributed import (
get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size,
get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
import vllm.envs as envs
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.model_executor.layers.vocab_parallel_embedding import (VocabParallelEmbedding,
ParallelLMHead)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.logger import init_logger
logger = init_logger(__name__)
def default_act_range_value():
return {
"x": None,
"split": None,
"is_linear": False,
"is_qkv": False,
"q_proj_size": 0,
"num_kv_head_replicas": 1,
"is_merge": False,
"input_id": [],
"self_rank": 0,
"rank": None,
"tensor_rank": None,
"tp_world_size": None,
"moe_tp_rank": None,
"moe_tp_world_size": None,
"moe_ep_rank": None,
"moe_ep_world_size": None,
"weight": None,
}
def _str_to_torch_dtype(dtype: str) -> torch.dtype:
dtype = dtype.split(".")[-1]
# STR_DTYPE_TO_TORCH_DTYPE dict does not have float16 type
return STR_DTYPE_TO_TORCH_DTYPE[dtype] if dtype != "float16" else torch.float16
class ActRangeValue:
"""
ActRangeValue for v1 MsgpackEncoder and MsgpackDecoder. This is a *WorkAround*.
The decode tensor can be wrong if we pass act range dict directly.
NOTE: here, we transfer torch.Tensor to numpy ndarray because torch.Tensor
may cause core dump.
"""
def __init__(self):
self.layer_name: str = ""
self.x: Optional[np.ndarray] = None
self.split: Optional[str] = None
self.is_linear: bool = False
self.is_qkv: bool = False
self.q_proj_size: int = 0
self.num_kv_head_replicas: int = 1
self.is_merge: bool = False
self.input_id_dtype: Optional[str] = None
self.input_id: Optional[List[np.ndarray]] = []
self.self_rank: int = 0
self.rank: Optional[int] = None
self.tensor_rank: Optional[int] = None
self.tp_world_size: Optional[int] = None
self.moe_tp_rank: Optional[int] = None
self.moe_tp_world_size: Optional[int] = None
self.moe_ep_rank: Optional[int] = None
self.moe_ep_world_size: Optional[int] = None
self.weight: Optional[np.ndarray] = None
self.weight_dtype: Optional[str] = None
@classmethod
def serial(cls, layer_name: str, act_range: Dict[str, Any]) -> 'ActRangeValue':
instance = cls()
instance.layer_name = layer_name
instance.x = act_range.get("x")
instance.split = act_range.get("split")
instance.is_linear = act_range.get("is_linear", False)
instance.is_qkv = act_range.get("is_qkv", False)
instance.q_proj_size = act_range.get("q_proj_size", 0)
instance.num_kv_head_replicas = act_range.get("num_kv_head_replicas", 1)
instance.is_merge = act_range.get("is_merge", False)
instance.input_id = act_range.get("input_id", [])
instance.self_rank = act_range.get("self_rank", 0)
instance.rank = act_range.get("rank")
instance.tensor_rank = act_range.get("tensor_rank")
instance.tp_world_size = act_range.get("tp_world_size")
instance.moe_tp_rank = act_range.get("moe_tp_rank")
instance.moe_tp_world_size = act_range.get("moe_tp_world_size")
instance.moe_ep_rank = act_range.get("moe_ep_rank")
instance.moe_ep_world_size = act_range.get("moe_ep_world_size")
instance.weight = act_range.get("weight")
if instance.x is not None:
instance.x = instance.x.numpy()
# input_id and weight are used for debug
if isinstance(instance.input_id, torch.Tensor):
instance.input_id_dtype = str(instance.input_id.dtype)
instance.input_id = instance.input_id.float().numpy()
else:
input_id_np = []
for input_id in instance.input_id:
instance.input_id_dtype = str(input_id.dtype)
input_id_np.append(input_id.float().numpy())
instance.input_id = input_id_np
if instance.weight is not None:
instance.weight_dtype = str(instance.weight.dtype)
instance.weight = instance.weight.float().numpy()
return instance
def deserial(self) -> Dict[str, Any]:
act_range = self.to_dict()
if self.x is not None:
act_range["x"] = torch.from_numpy(self.x)
if self.input_id is not None:
# After serial(), a single input_id is an np.ndarray, not a torch.Tensor.
if isinstance(self.input_id, np.ndarray):
act_range["input_id"] = torch.from_numpy(self.input_id).to(
_str_to_torch_dtype(self.input_id_dtype))
else:
input_id_tensor = []
for input_id in self.input_id:
input_id_tensor.append(torch.from_numpy(input_id).to(
_str_to_torch_dtype(self.input_id_dtype)))
act_range["input_id"] = input_id_tensor
if self.weight_dtype is not None:
act_range["weight"] = torch.from_numpy(self.weight).to(
_str_to_torch_dtype(self.weight_dtype))
return act_range
def to_dict(self) -> Dict[str, Any]:
return {
"x": self.x,
"split": self.split,
"is_linear": self.is_linear,
"is_qkv": self.is_qkv,
"q_proj_size": self.q_proj_size,
"num_kv_head_replicas": self.num_kv_head_replicas,
"is_merge": self.is_merge,
"input_id": self.input_id,
"self_rank": self.self_rank,
"rank": self.rank,
"tensor_rank": self.tensor_rank,
"tp_world_size": self.tp_world_size,
"moe_tp_rank": self.moe_tp_rank,
"moe_tp_world_size": self.moe_tp_world_size,
"moe_ep_rank": self.moe_ep_rank,
"moe_ep_world_size": self.moe_ep_world_size,
"weight": self.weight,
}
def __repr__(self) -> str:
return f"layer: {self.layer_name}, ActRangeValue({self.to_dict()})"
class MLUWorkerQuant(object):
'''
MLU quantization calibration mixin.
'''
def stat_tensor(self, name, tensor, act_range, key, dim):
logger.debug(f"name:{name}, key:{key}, dim:{dim}, tensor.shape:{tensor.shape}")
hidden_dim = tensor.shape[-1]
tensor = tensor.view(-1, hidden_dim).abs()
coming_max = torch.max(tensor, dim=dim)[0].float()
if act_range[name][key] is None:
act_range[name][key] = coming_max
else:
act_range[name][key] = torch.max(act_range[name][key], coming_max)
def stat_input_hook(self, m, x, y, name, act_range, is_linear, is_save_input_id):
if isinstance(x, tuple):
x = x[0]
if isinstance(y, tuple):
y = y[0]
logger.debug(f"name:{name}, x.shape:{x.shape}, y.shape:{y.shape}, m.weight.shape:{m.weight.shape}")
if is_linear:
self.stat_tensor(name, x, act_range, "x", 0)
if act_range[name]["is_qkv"] and is_save_input_id and ".0." in name:
x_cpu = x.clone().to("cpu")
act_range[name]["input_id"].append(x_cpu)
def setup_smooth_hook(self, is_save_input_id: bool = False, is_save_moe_info: bool = False):
models = [self.model_runner.model]
if hasattr(self.model_runner, "drafter") and self.model_runner.drafter is not None:
models += [self.model_runner.drafter.model]
self.act_range = defaultdict(default_act_range_value)
self.hooks = []
linear_class_list = (ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
other_class_list = (VocabParallelEmbedding, ParallelLMHead)
class_list = linear_class_list + other_class_list
row_class_list = (RowParallelLinear,)
for model in models:
for name, m in model.named_modules():
if isinstance(m, FeedForward):
m.use_bt_ffn = False
if isinstance(m, SparseMoeMlp):
m.is_use_fused_moe = False
if isinstance(m, DeepseekV2MLAAttention):
m.use_fused_mla_qkv = False
if isinstance(m, class_list):
is_linear = isinstance(m, linear_class_list)
split_type = "row" if isinstance(m, row_class_list) else "col"
self.act_range[name]["split"] = split_type
self.act_range[name]["is_linear"] = is_linear
if isinstance(m, QKVParallelLinear):
self.act_range[name]["is_qkv"] = True
self.act_range[name]["q_proj_size"] = m.num_heads * m.head_size
self.act_range[name]["num_kv_head_replicas"] = m.num_kv_head_replicas
self.act_range[name]["is_merge"] = isinstance(m, MergedColumnParallelLinear)
if is_save_moe_info:
self.act_range[name]["rank"] = torch.distributed.get_rank()
self.act_range[name]["tensor_rank"] = get_tensor_model_parallel_rank()
self.act_range[name]["tp_world_size"] = get_tensor_model_parallel_world_size()
self.act_range[name]["moe_tp_rank"] = get_moe_tensor_parallel_rank()
self.act_range[name]["moe_tp_world_size"] = get_moe_tensor_parallel_world_size()
self.act_range[name]["moe_ep_rank"] = get_moe_expert_parallel_rank()
self.act_range[name]["moe_ep_world_size"] = get_moe_expert_parallel_world_size()
if ".expert." in name:
self.act_range[name]["weight"] = m.weight
logger.info(f"rank:{self.rank}, add hook to {name}, is_linear:{is_linear}, split_type:{split_type}")
self.hooks.append(m.register_forward_hook(functools.partial(self.stat_input_hook,
name=name, act_range=self.act_range,
is_linear=is_linear,
is_save_input_id=is_save_input_id)))
def remove_hooks(self):
for h in self.hooks:
h.remove()
def get_act_range(self):
act_range = defaultdict(default_act_range_value)
for layer_name, layer_range in self.act_range.items():
for tensor_key, tensor_value in layer_range.items():
if isinstance(tensor_value, torch.Tensor):
act_range[layer_name][tensor_key] = tensor_value.to("cpu")
elif tensor_key == "input_id" and isinstance(tensor_value, list):
input_id_len = len(tensor_value)
for i in range(input_id_len):
if isinstance(tensor_value[i], torch.Tensor):
act_range[layer_name][tensor_key].append(tensor_value[i].to("cpu"))
else:
act_range[layer_name][tensor_key].append(tensor_value[i])
else:
act_range[layer_name][tensor_key] = tensor_value
serialization_result = []
for layer_name, layer_range in act_range.items():
serialization_result.append(ActRangeValue.serial(layer_name, layer_range))
return serialization_result
@torch.no_grad()
def get_named_parameters(self):
name_parameters = {}
for name, param in self.model_runner.model.named_parameters():
name_parameters[name] = param.to("cpu")
return name_parameters
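# Illustrative sketch (not part of the diff): stat_tensor() above maintains a
# running per-channel maximum of |activations|, the statistic that
# smooth-quant style calibration consumes. The core reduction on a
# hypothetical [tokens, hidden] tensor:
import torch

x = torch.tensor([[0.1, -3.0],
                  [2.0,  0.5]])
running = None
coming_max = torch.max(x.view(-1, x.shape[-1]).abs(), dim=0)[0].float()
running = coming_max if running is None else torch.max(running, coming_max)
# running -> tensor([2., 3.])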