init v0.11.0rc0

This commit is contained in:
2025-10-14 10:38:28 +08:00
parent 67afd0ea78
commit 66dc16f966
278 changed files with 28130 additions and 11708 deletions

View File

@@ -23,5 +23,7 @@ def register():
def register_model():
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
from .models import register_model
register_model()

View File

@@ -34,6 +34,8 @@ class AscendConfig:
def __init__(self, vllm_config):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32"
self.use_sfa = self.is_deepseek_sfa
torchair_graph_config = additional_config.get("torchair_graph_config",
{})
@@ -43,13 +45,26 @@ class AscendConfig:
"ascend_scheduler_config", {})
self.ascend_scheduler_config = AscendSchedulerConfig(
ascend_scheduler_config)
# TODO: Remove this config once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm.
self.expert_map_path = additional_config.get("expert_map_path", None)
self.expert_map_record_path = additional_config.get(
"expert_map_record_path",
None) # Provide path to export expert map
self.init_redundancy_expert = additional_config.get(
"init_redundancy_expert", 0)
self.dynamic_eplb = additional_config.get("dynamic_eplb", False)
self.num_iterations_eplb_update = additional_config.get(
"num_iterations_eplb_update", 400)
self.gate_eplb = additional_config.get("gate_eplb", False)
self.num_wait_worker_iterations = additional_config.get(
"num_wait_worker_iterations", 30)
self.chunked_prefill_for_mla = additional_config.get(
"chunked_prefill_for_mla", False)
self.enable_shared_expert_dp = additional_config.get(
"enable_shared_expert_dp", False
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.multistream_overlap_shared_expert = additional_config.get(
"multistream_overlap_shared_expert", False)
self.enable_prefetch = additional_config.get("enable_prefetch", False)
self.lmhead_tensor_parallel_size = additional_config.get(
"lmhead_tensor_parallel_size", None)
@@ -61,6 +76,24 @@ class AscendConfig:
raise AssertionError(
"lmhead_tensor_parallel_size is only supported in the pure DP scenario"
)
self.oproj_tensor_parallel_size = additional_config.get(
"oproj_tensor_parallel_size", None)
if self.oproj_tensor_parallel_size is not None:
logger.info(
f"Enable oproj_tensor_parallel_size={self.oproj_tensor_parallel_size} in pure DP scenario"
)
if vllm_config.parallel_config.tensor_parallel_size != 1:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in the pure DP scenario"
)
if not self.torchair_graph_config.enabled:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in graph mode"
)
if vllm_config.kv_transfer_config is None or not vllm_config.kv_transfer_config.is_kv_consumer:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
)
class TorchairGraphConfig:
@@ -81,10 +114,10 @@ class TorchairGraphConfig:
"graph_batch_sizes_init", False)
self.enable_multistream_mla = torchair_graph_config.get(
"enable_multistream_mla", False)
self.enable_multistream_moe = torchair_graph_config.get(
"enable_multistream_moe", False)
self.enable_view_optimize = torchair_graph_config.get(
"enable_view_optimize", True)
self.enable_frozen_parameter = torchair_graph_config.get(
"enable_frozen_parameter", True)
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)
if not isinstance(self.graph_batch_sizes, list):
@@ -117,10 +150,6 @@ class TorchairGraphConfig:
raise RuntimeError(
"enable_multistream_mla is valid only when Torchair graph mode is enabled"
)
if self.enable_multistream_moe:
raise RuntimeError(
"enable_multistream_moe is valid only when Torchair graph mode is enabled"
)
if self.enable_kv_nz:
raise RuntimeError(
"enable_kv_nz is valid only when Torchair graph mode is enabled"
@@ -200,14 +229,8 @@ def check_ascend_config(vllm_config, enforce_eager):
"it has been disabled automatically.")
# aclgraph case
else:
# aclgraph doesn't work with deepseek model and only qwen model is well tested.
if vllm_config.model_config:
model_type = vllm_config.model_config.hf_config.model_type
if "deepseek" in model_type:
raise NotImplementedError(
"ACL Graph does not support deepseek. Please "
"try torchair graph mode to serve deepseek models on vllm-ascend."
" Or set `enforce_eager=True` to use eager mode.")
if "qwen" not in model_type:
logger.warning(
"ACL Graph is currently experimental. Please "

View File

@@ -11,6 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
class FusedMoEState(Enum):
@@ -22,6 +23,13 @@ class FusedMoEState(Enum):
All2AllSeq = 5
class MoECommType(Enum):
    """Enumerates the MoE communication types.

    A member is accepted by ``set_ascend_forward_context`` as
    ``moe_comm_type``, stored on the forward context, and passed to
    ``get_moe_comm_method`` to obtain the communication method.
    """

    ALLGATHER = 0
    MC2 = 1
    ALLTOALL = 2
    NAIVE_MULTICAST = 3
# TODO(zzzzwwjj): add soc_version to choose branch
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
@@ -42,18 +50,6 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool,
return FusedMoEState.MC2
def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str:
    """Pick the name of the token dispatcher to use.

    Args:
        ep_size: Expert-parallel world size (1 when expert parallelism
            is not enabled by the caller).
        with_prefill: Prefill flag forwarded by the caller.

    Returns:
        ``"TokenDispatcherWithAllGather"`` when ``ep_size`` is 1,
        ``"TokenDispatcherWithMC2"`` when ``ep_size`` is at least 16 and
        ``with_prefill`` is false, and ``"TokenDispatcherWithAll2AllV"``
        in every other case.
    """
    if ep_size == 1:
        return "TokenDispatcherWithAllGather"
    # MC2 is selected only for large EP groups when no prefill is flagged.
    use_mc2 = ep_size >= 16 and not with_prefill
    if use_mc2:
        return "TokenDispatcherWithMC2"
    return "TokenDispatcherWithAll2AllV"
@contextmanager
def set_ascend_forward_context(
attn_metadata: Any,
@@ -64,10 +60,12 @@ def set_ascend_forward_context(
with_prefill: bool = True,
in_profile_run: bool = False,
reserved_mc2_mask: Optional[torch.Tensor] = None,
moe_comm_method: str = "",
moe_comm_type: Optional[MoECommType] = None,
num_actual_tokens: Optional[int] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None):
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -82,8 +80,13 @@ def set_ascend_forward_context(
batch_descriptor=batch_descriptor,
):
forward_context = get_forward_context()
forward_context.moe_comm_method_name = moe_comm_method + "commimpl"
from vllm_ascend.ops.moe.moe_comm_method import get_moe_comm_method
forward_context.moe_comm_type = moe_comm_type
forward_context.moe_comm_method = get_moe_comm_method(moe_comm_type)
forward_context.with_prefill = with_prefill
tp_world_size = get_tensor_model_parallel_world_size()
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
@@ -95,16 +98,63 @@ def set_ascend_forward_context(
forward_context.fused_moe_state = fused_moe_state
forward_context.in_profile_run = in_profile_run
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
get_token_dispatcher
dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
dispatcher = get_token_dispatcher(dispatcher_name)
forward_context.token_dispatcher = dispatcher
# NOTE: This cannot be set using set_forward_context
# due to multiple warmups before actual capturing
forward_context.capturing = False
# Set for sequence parallelism. 1000 is the concurrency threshold (compared against num_tokens) for enabling the flashcomm_v1 / sequence_parallelism feature.
# It is currently an empirical value. In normal scenarios, the performance benefit is maximized when the concurrency exceeds this threshold.
# Conversely, when the concurrency is below the threshold,
# performance may degrade due to the switching of communication methods.
sp_enabled = enable_sp(vllm_config) and \
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000
if sp_enabled:
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size
forward_context.sp_enabled = sp_enabled
# Set this for use by rope forward_oot.
forward_context.is_first_layer = True
# set layer_idx to enable optimization features that depend on this information.
# This is only applicable to models that contain these necessary attributes.
forward_context.layer_idx = None
if model_instance is not None and \
hasattr(model_instance, "model") and \
hasattr(model_instance.model, "start_layer"):
forward_context.layer_idx = model_instance.model.start_layer
# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
forward_context.layer_idx is not None and \
num_tokens is not None and num_tokens < 500
if prefetch_mlp_enabled:
forward_context.prefetch_stream = prefetch_stream
forward_context.model_instance = model_instance
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
# It will be improved later by implementing operator fusion through the FX graph.
#
# Set for addrmsnorm+quant fusion.
# This optimization currently supports only dense models because of the specific operators used.
# Once the necessary conditions are met, support for MoE models will also be added.
from vllm_ascend.quantization.quant_config import AscendQuantConfig
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
forward_context.layer_idx is not None
if addrmsnorm_quant_fusion_enabled:
forward_context.model_instance = model_instance
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
if num_tokens is None and attn_metadata is not None:
num_tokens = attn_metadata.num_actual_tokens
@@ -120,7 +170,6 @@ def set_ascend_forward_context(
if num_tokens is not None:
if num_actual_tokens is None:
num_actual_tokens = num_tokens
tp_world_size = get_tensor_model_parallel_world_size()
# NOTE: the number of tokens to pad to when using mc2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size

View File

@@ -39,11 +39,22 @@ class AttentionMaskBuilder:
self,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device = None,
):
# NOTE: The device argument specifies the target NPU
# to be used for the newly added FIA operator.
# Only pass this parameter when using the new FIA operator.
attn_mask = _generate_attn_mask(max_seq_len, dtype)
self._seq_len_cached = attn_mask.shape[0]
self.attn_mask_cache = attn_mask
self.device = device
if torch.version.cann.startswith("8.3"):
assigned_mask_dim = 2048
self.chunked_prefill_attn_mask = torch.triu(
torch.ones(assigned_mask_dim, assigned_mask_dim),
diagonal=1).to(torch.int8).to(device)
@staticmethod
def get_mask_scale_factor(dtype: torch.dtype = torch.float16):
@@ -62,28 +73,32 @@ class AttentionMaskBuilder:
device: torch.device):
self._update_attn_cache(max_seq_len, dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
).to(device)
).to(device, non_blocking=True)
def get_splitfuse_attn_mask(
self,
seq_lens: torch.Tensor,
position: torch.Tensor,
dtype: torch.dtype,
device: torch.device,
seq_lens: torch.Tensor = None,
position: torch.Tensor = None,
dtype: torch.dtype = None,
device: torch.device = None,
) -> torch.Tensor:
if dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
"splitfuse_attn_mask now only supports bf16 and fp16")
max_seq_len = max(seq_lens, default=0)
self._update_attn_cache(max_seq_len, dtype)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(dtype)
attn_mask = torch.index_select(self.attn_mask_cache,
dim=0,
index=position)[:, :max_seq_len]
attn_mask *= mask_scale_factor
return attn_mask.contiguous().to(device, non_blocking=True)
if torch.version.cann.startswith("8.3"):
return self.chunked_prefill_attn_mask
else:
if dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(
"splitfuse_attn_mask now only supports bf16 and fp16")
max_seq_len = max(seq_lens, default=0)
self._update_attn_cache(max_seq_len, dtype)
# FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
# is not the same. Fix this in the future when kernel is ready.
mask_scale_factor = AttentionMaskBuilder.get_mask_scale_factor(
dtype)
attn_mask = torch.index_select(self.attn_mask_cache,
dim=0,
index=position)[:, :max_seq_len]
attn_mask *= mask_scale_factor
return attn_mask.contiguous().to(device, non_blocking=True)
def _update_attn_cache(self, seqlen: int, dtype: torch.dtype):
if seqlen > self._seq_len_cached:

View File

@@ -17,24 +17,27 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Tuple, Type
from typing import ClassVar, List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch_npu
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import cdiv, direct_register_custom_op
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import get_graph_params
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
from vllm_ascend.worker.npu_input_batch import InputBatch
class AscendAttentionBackend(AttentionBackend):
@@ -52,10 +55,6 @@ class AscendAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AscendMetadata"]:
return AscendMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["AscendAttentionMetadataBuilder"]:
return AscendAttentionMetadataBuilder
@@ -111,6 +110,10 @@ class AscendAttentionBackend(AttentionBackend):
key_caches[dst_indices] = key_caches[src_indices]
value_caches[dst_indices] = value_caches[src_indices]
@staticmethod
def get_supported_block_size() -> list[int]:
return [64]
class AscendAttentionState(Enum):
PrefillNoCache = 0
@@ -155,48 +158,50 @@ class AscendMetadata:
# *************************** Other Properties *************************** #
enable_dbo_across_dp: bool = False
is_only_prefill: bool = False
class AscendAttentionMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
# Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[int] = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
vllm_config.cache_config.block_size)
self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len,
AscendAttentionBackend.get_supported_block_size()[0])
def reorder_batch(self, input_batch: "InputBatch",
def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
return False
def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
model: Optional[nn.Module] = None,
):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
num_reqs
+ 1]
block_table = common_attn_metadata.block_table_tensor
block_table[:num_reqs, :self.max_num_blocks_per_req] = (
block_table[:num_reqs])
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
self.device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
attn_mask = common_attn_metadata.attn_mask
attn_state = common_attn_metadata.attn_state
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
@@ -225,8 +230,25 @@ class AscendAttentionMetadataBuilder:
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
is_only_prefill=common_attn_metadata.is_only_prefill)
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
return attn_metadata
def build_for_graph_capture(
    self,
    common_attn_metadata: AscendCommonAttentionMetadata,
    attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
):
    """Build attention metadata for graph capture.

    Args:
        common_attn_metadata: Common metadata forwarded unchanged to
            ``self.build``.
        attn_state: Requested attention state. Only
            ``AscendAttentionState.DecodeOnly`` is supported.

    Returns:
        The metadata produced by ``self.build`` (called with
        ``common_prefix_len=0``), with its ``attn_state`` attribute set to
        the requested state.

    Raises:
        NotImplementedError: If ``attn_state`` is not ``DecodeOnly``.
    """
    # Reject unsupported states up front; nothing is built in that case.
    if attn_state != AscendAttentionState.DecodeOnly:
        raise NotImplementedError(
            "Currently we only support building dummy metadata for DecodeOnly state"
        )

    metadata = self.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )
    metadata.attn_state = attn_state
    return metadata
@@ -265,20 +287,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.key_cache = None
self.value_cache = None
def _repeat_kv(self, hidden_states: torch.Tensor,
n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, None, :, :].expand(
num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(num_key_value_heads * n_rep, slen,
head_dim)
def _forward_prefill_no_cache(
self,
query: torch.Tensor,
@@ -304,34 +312,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
mask = torch_npu.npu_format_cast(mask.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
if self.sliding_window is not None and \
attn_metadata.attn_mask.shape[0] > self.sliding_window:
key = self._repeat_kv(key, self.num_heads // self.num_kv_heads)
value = self._repeat_kv(value, self.num_heads // self.num_kv_heads)
output, _ = torch_npu.npu_fused_infer_attention_score(
query,
key,
value,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
input_layout="TND",
pre_tokens=self.sliding_window,
scale=self.scale,
actual_seq_lengths=attn_metadata.seq_lens,
actual_seq_lengths_kv=attn_metadata.seq_lens)
output = output.view(num_tokens, self.num_heads, self.head_size)
else:
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
torch_npu._npu_flash_attention(query=query,
key=key,
value=value,
mask=mask,
seq_len=attn_metadata.seq_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=output)
assert output is not None
return output[:num_tokens, :, :]
@@ -372,7 +361,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
# seq_lens_tensor needs to be transferred to the device for 310P.
attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device)
if self.sliding_window is not None:
if self.sliding_window is not None and attn_metadata.seq_lens.shape[
0] == query.size(0):
batch_size = attn_metadata.seq_lens.shape[0]
block_size = 128
query = query.view(batch_size, 1, self.num_heads * self.head_size)
@@ -399,16 +389,53 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = output.view(batch_size, self.num_heads, self.head_size)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
query,
self.key_cache,
self.value_cache,
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
output,
))
torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
return output
def _forward_v1_style(
@@ -449,18 +476,43 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata.seq_lens = \
attn_metadata.seq_lens.to(device=query.device)
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
if torch.version.cann.startswith("8.3"):
# TODO: The npu_fused_infer_attention_score op is planned to
# be used more widely in upcoming versions.
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
key = self.key_cache.view( # type: ignore
num_block, block_size, -1)
value = self.value_cache.view( # type: ignore
num_block, block_size, -1)
output, _ = torch_npu.npu_fused_infer_attention_score(
query=query,
key=key,
value=value,
atten_mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=attn_metadata.query_start_loc[1:],
actual_seq_lengths_kv=attn_metadata.seq_lens,
num_key_value_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale=self.scale,
sparse_mode=3,
)
else:
torch_npu._npu_paged_attention_splitfuse(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
mask=attn_metadata.attn_mask,
block_table=attn_metadata.block_tables,
seq_len=attn_metadata.query_lens,
context_lens=attn_metadata.seq_lens,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
out=output)
return output
def forward(
@@ -554,12 +606,18 @@ class AscendAttentionBackendImpl(AttentionImpl):
output)
# Normal V1 situation.
else:
if torch.version.cann.startswith("8.3"):
# npu_fused_infer_attention_score does not support cases
# where query.shape[0] != attn_metadata.query_start_loc[-1].
# Thus we need to unpad it here.
num_tokens = attn_metadata.query_start_loc[-1]
query = query[:num_tokens]
output = self._forward_v1_style(query, attn_metadata, output)
# to make in-place change to the output tensor
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
output = output.view(num_tokens, self.num_heads, self.head_size)
ori_output[:, :, :] = output[:num_tokens, :, :]
ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
return output.view(num_tokens, self.hidden_size)
@@ -570,8 +628,11 @@ def unified_ascend_attention_with_output(
output: torch.Tensor,
layer_name: str,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
@@ -582,6 +643,7 @@ def unified_ascend_attention_with_output(
attn_metadata,
output,
trace_flag=False)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return

View File

@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Type, TypeVar
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
TypeVar)
import torch
import torch_npu
@@ -12,15 +13,17 @@ from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
wait_for_kv_layer_from_connector)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -164,6 +167,9 @@ M = TypeVar("M", bound=AscendMLAMetadata)
class AscendMLAMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -171,6 +177,8 @@ class AscendMLAMetadataBuilder:
# _attn_mask_builder = None
def __init__(self,
kv_cache_spec,
layer_names,
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendMLAMetadata] = None):
@@ -185,7 +193,16 @@ class AscendMLAMetadataBuilder:
self.block_size - 1) // self.block_size
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
self.reorder_batch_threshold = self.decode_threshold
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
@@ -265,6 +282,7 @@ class AscendMLAMetadataBuilder:
def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendMLAMetadata:
@@ -272,7 +290,6 @@ class AscendMLAMetadataBuilder:
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
@@ -284,11 +301,7 @@ class AscendMLAMetadataBuilder:
device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
num_actual_tokens].to(
device,
non_blocking=
True)
slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
@@ -376,11 +389,12 @@ class AscendMLAMetadataBuilder:
decode_metadata = None
if num_decodes > 0:
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decode_tokens]
seq_lens = seq_lens[:num_decodes]
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decode_tokens, ...]
block_table = block_table[:num_decodes, ...]
seq_lens_list = seq_lens.tolist()
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
@@ -481,17 +495,12 @@ class AscendMLAImpl(MLAAttentionImpl):
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.enable_prefetch
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
vllm_config = get_current_vllm_config()
self.ring_mla_mask_size = 512
self.prefill_mask = None
# Adapt torch air graph mode with spec decoding.
speculative_config = vllm_config.speculative_config
if speculative_config is not None:
self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0
self.speculative_config = vllm_config.speculative_config
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
@@ -663,84 +672,47 @@ class AscendMLAImpl(MLAAttentionImpl):
self.v_head_dim,
dtype=q_nope.dtype,
device=q_nope.device)
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
query = torch.cat((q_nope, q_pe), dim=-1)
key = torch.cat((k_nope, k_pe), dim=-1)
torch_npu._npu_flash_attention(
query=query,
key=key,
value=value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)
elif self.chunked_prefill_for_mla:
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=q_nope.device)
if self.prefill_mask is None:
self.prefill_mask = torch.triu(
torch.ones(self.ring_mla_mask_size,
self.ring_mla_mask_size,
device=q_nope.device,
dtype=q_nope.dtype), 1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=self.prefill_mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
else:
query = torch.cat((q_nope, q_pe), dim=-1)
attn_output_torch = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
dtype=query.dtype,
device=query.device)
# current requests is chunked in prefill, disable flash attention with chunked prefill
vanilla_chunked_prefill_mla(
output=attn_output_torch,
query=query,
kv_cache=kv_c_and_k_pe_cache,
block_tables=attn_metadata.prefill.block_table,
query_lens=attn_metadata.prefill.query_lens,
context_lens=attn_metadata.prefill.context_lens,
kv_b_proj=self.kv_b_proj,
max_query_len=attn_metadata.prefill.max_query_len,
max_context_len=attn_metadata.prefill.max_seq_lens,
nope_dim=self.qk_nope_head_dim,
rope_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
scale=self.scale,
alibi_slopes=None,
causal=True)
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=q_nope.device)
if self.prefill_mask is None:
if q_nope.dtype == torch.float16:
mask_value = torch.finfo(torch.float32).min
else:
mask_value = 1
prefill_mask = torch.triu(
torch.ones(self.ring_mla_mask_size,
self.ring_mla_mask_size,
device=q_nope.device,
dtype=q_nope.dtype), 1)
self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
0).to(q_nope.dtype)
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=self.prefill_mask,
seqlen=torch.tensor(
attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding,
AscendAttentionState.PrefillCacheHit
] and not self.chunked_prefill_for_mla:
attn_output = attn_output_torch
return attn_output
def exec_kv_decode(
@@ -785,7 +757,7 @@ class AscendMLAImpl(MLAAttentionImpl):
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA"
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
@@ -840,8 +812,11 @@ class AscendMLAImpl(MLAAttentionImpl):
self.qk_rope_head_dim)
input_layout = "BNSD"
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
assert num_tokens % self.spec_token_num == 0
if attn_metadata.attn_state in [
AscendAttentionState.SpecDecoding,
AscendAttentionState.ChunkedPrefill
] and self.speculative_config is not None:
# Use TND layout for pure SpecDecoding and SpecDecoding in ChunkedPrefill
input_layout = "TND"
# [bs * q_seq_len, num_heads_per_rank, dim]
q_nope = q_nope.view(num_tokens, self.num_heads, -1)
@@ -887,8 +862,8 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv):
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
attn_metadata, need_gather_q_kv):
# MLA Preprocess:
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -917,6 +892,8 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
wait_for_kv_layer_from_connector(layer_name)
# Preprocess for decode tokens
if has_decode:
decode_q_c = q_c[:num_decode_tokens]
@@ -963,6 +940,7 @@ class AscendMLAImpl(MLAAttentionImpl):
def forward(
self,
layer_name,
hidden_states: torch.Tensor, # query in unified attn
kv_cache: Tuple[torch.Tensor],
attn_metadata: M,
@@ -989,7 +967,8 @@ class AscendMLAImpl(MLAAttentionImpl):
# MLA Preprocess
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
layer_name, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv)
if decode_preprocess_res is not None:
# MLA Preprocess for decoding
@@ -1047,4 +1026,8 @@ class AscendMLAImpl(MLAAttentionImpl):
is_force_scatter=self.enable_shared_expert_dp)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input
has_prefill = attn_metadata.num_prefills > 0
if has_prefill:
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
return output_padded

View File

@@ -0,0 +1,986 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
TypeVar)
import torch
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
MLAAttentionImpl)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class AscendSFABackend(AttentionBackend):
    """Backend descriptor that wires together the Ascend SFA classes.

    Every accessor is a static method returning either a constant or one of
    the classes defined in this module (metadata, metadata builder, impl).
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "ASCEND_SFA"

    @staticmethod
    def get_impl_cls() -> Type["AscendSFAImpl"]:
        """Return the attention implementation class."""
        return AscendSFAImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the metadata class produced by the builder."""
        return AscendSFAMetadata

    @staticmethod
    def get_builder_cls():
        """Return the metadata builder class."""
        return AscendSFAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                           head_size: int) -> tuple[int, ...]:
        """Return the cache shape as (blocks, block size, kv heads, head size)."""
        shape = (num_blocks, block_size, num_kv_heads, head_size)
        return shape
@dataclass
class AscendSFAPrefillMetadata:
    """Prefill-specific metadata for the Ascend SFA attention backend.

    Field order is part of the generated ``__init__`` signature; do not
    reorder.
    """

    @dataclass
    class ChunkedContextMetadata:
        """Bookkeeping for processing already-computed context in chunks."""
        # New for MLA (compared to FlashAttention)
        # For handling chunked prefill
        # The builder in this file fills these as follows:
        #   cu_seq_lens    - cumulative sum of chunk_seq_lens along dim 1,
        #                    with a leading zero column
        #   starts         - start offset of every chunk for every prefill
        #   seq_tot        - per-chunk sum of chunk_seq_lens
        #   max_seq_lens   - per-chunk maximum of chunk_seq_lens
        #   workspace      - the builder's preallocated workspace tensor
        #   chunk_seq_lens - clamp(chunk_end - chunk_start, min=0)
        cu_seq_lens: torch.Tensor
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
        workspace: torch.Tensor
        chunk_seq_lens: torch.Tensor

    attn_mask: torch.Tensor
    # NOTE(review): annotated as list[int], but AscendSFAMetadataBuilder.build
    # in this file passes int32 tensors for both query_lens (a cumulative sum
    # of the prefill query lengths) and seq_lens - confirm and align the
    # annotations.
    query_lens: list[int]
    seq_lens: list[int]
    # NOTE(review): the builder in this file passes the prefill slice of the
    # sequence lengths here, not the computed-token counts - confirm intended.
    context_lens: torch.Tensor
    input_positions: torch.Tensor
    query_start_loc: torch.Tensor
    block_table: torch.Tensor
    max_query_len: int
    max_seq_lens: int
    # Rotary sin/cos values gathered at the prefill token positions.
    sin: torch.Tensor
    cos: torch.Tensor
    chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
class AscendSFADecodeMetadata:
    """Decode-specific metadata for the Ascend SFA attention backend.

    Field order is part of the generated ``__init__`` signature; do not
    reorder.
    """
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    max_seq_lens: int
    # Same values as seq_lens, as a plain Python list.
    seq_lens_list: list[int]
    # The builder in this file fills this from query_start_loc[1:num_decodes+1],
    # i.e. the end offset of every decode request's query tokens.
    actual_seq_lengths_q: torch.Tensor
    # Rotary sin/cos values gathered at the decode token positions.
    sin: torch.Tensor
    cos: torch.Tensor
    attn_mask: Optional[torch.Tensor] = None
@dataclass
class AscendSFAMetadata:
    """Per-step attention metadata for the Ascend SFA backend.

    Carries batch-wide fields plus optional ``decode`` / ``prefill``
    sub-metadata. Field order is part of the generated ``__init__``
    signature; do not reorder.
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|
    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_tables: torch.Tensor
    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.
    query_lens: Optional[list[int]] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    attn_mask: torch.Tensor = None
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
    decode: Optional[AscendSFADecodeMetadata] = None
    prefill: Optional[AscendSFAPrefillMetadata] = None
    enable_dbo_across_dp: bool = False

    def __post_init__(self):
        # Intentionally a no-op: validation of head_dim against a list of
        # supported head sizes is currently disabled.
        pass

    def split_metadata_for_multistream(
        self,
        ms_split_config: MSAttentionMetadataSplitConfig,
    ) -> list["AscendSFAMetadata"]:
        """Split this metadata for multi-stream execution.

        Delegates to ``model_input_split_v1_mla_attn`` with this instance and
        the given split configuration, and returns its result.
        """
        # NOTE(review): the helper is handed AscendMLAMetadata as the class to
        # build, while the return annotation promises AscendSFAMetadata -
        # confirm which one is intended.
        return model_input_split_v1_mla_attn(
            ms_split_config=ms_split_config,
            attn_metadata=self,
            _metadata_cls=AscendMLAMetadata,
        )
# Type variable for attention-metadata parameters; bound to AscendSFAMetadata.
M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder:
    """Builds :class:`AscendSFAMetadata` for every scheduled batch.

    The builder splits a batch into "decode" requests (at most
    ``decode_threshold`` scheduled tokens) and "prefill" requests, and
    produces the per-phase sub-metadata consumed by the SFA implementation.
    """

    # Does this backend/builder support ACL Graphs for attention (default: no).
    aclgraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER

    def __init__(self,
                 kv_cache_spec,
                 layer_names,
                 vllm_config: VllmConfig,
                 device: torch.device,
                 metadata_cls: Optional[Type[AscendSFAMetadata]] = None):
        """Capture configuration and preallocate the chunked-prefill workspace.

        Args:
            kv_cache_spec: Unused here; kept for the builder interface.
            layer_names: Unused here; kept for the builder interface.
            vllm_config: Global configuration; model, scheduler, cache and
                speculative settings are read from it.
            device: Device on which metadata tensors are created.
            metadata_cls: Class instantiated by :meth:`build`; defaults to
                :class:`AscendSFAMetadata`.
        """
        self.metadata_cls: Optional[Type[AscendSFAMetadata]] = metadata_cls \
            if metadata_cls is not None else AscendSFAMetadata  # type: ignore
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.device = device
        scheduler_config = vllm_config.scheduler_config
        self.block_size = vllm_config.cache_config.block_size
        self.max_blocks = (vllm_config.model_config.max_model_len +
                           self.block_size - 1) // self.block_size
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled

        self.speculative_config = vllm_config.speculative_config
        # A request counts as "decode" when it schedules at most this many
        # tokens: one, plus the speculative tokens when spec decoding is on.
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                "decode_threshold exceeded npu_fused_infer_attention_score "
                f"TND layout's limit of 16, got {self.decode_threshold}")

        if self.chunked_prefill_enabled:
            self.chunked_prefill_workspace_size = min(
                # Make sure there is enough for 8 full length requests or at
                # least 4 pages of cache per request
                max(8 * self.model_config.max_model_len,
                    4 * scheduler_config.max_num_seqs * self.block_size),
                # For long-context models try not to over-allocate and starve
                # the kv-cache: cap the workspace at 128 * 1024 tokens.
                128 * 1024)
            assert self.chunked_prefill_workspace_size >= \
                scheduler_config.max_num_seqs * self.block_size
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 self.model_config.get_head_size()),
                dtype=self.model_config.dtype,
                device=device,
            )

        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
        # Rotary tables are fetched lazily from the model on the first build().
        self.cos_cache = None
        self.sin_cache = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move decode requests to the front of the batch.

        Requests scheduling at most ``decode_threshold`` tokens are treated
        as decodes, everything else as prefills. Decodes that sit past the
        decode region are swapped with the front-most prefills.

        Returns:
            True when at least one pair of requests was swapped.
        """
        # (NOTE for now we loosely use "decode" to mean requests where
        # attention is likely memory-bound and "prefill" to mean requests
        # where attention is likely compute-bound.)
        decodes = []
        prefills = []
        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if num_tokens <= self.decode_threshold:
                decodes.append(i)
            else:
                prefills.append(i)

        # `decodes` and `prefills` are in ascending order thanks to the loop
        # above. Walk the decodes from the back and the prefills from the
        # front; swapping stops as soon as a decode is already inside the
        # decode region, which keeps the number of swaps minimal.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False
        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, swap it with the
            # prefill closest to the front of the batch.
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill],
                                        decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break
        return modified_batch

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        model: nn.Module,
    ) -> AscendSFAMetadata:
        """Assemble the attention metadata for the current batch.

        Args:
            common_prefix_len: Unused here; kept for the builder interface.
            common_attn_metadata: Batch-wide inputs (query offsets, sequence
                lengths, block table, slot mapping, positions, masks).
            model: Model whose first layer's rotary embedding provides the
                cos/sin tables (read once, then cached on the builder).

        Returns:
            An instance of ``self.metadata_cls`` with ``prefill`` and/or
            ``decode`` sub-metadata populated.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=self.decode_threshold)
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

        # Note(simon): be careful about the CPU <> device memory movement in
        # this function. We should avoid device -> CPU sync as much as
        # possible because it blocks on all previous kernels.
        device = self.device
        block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens].to(
            device, non_blocking=True)
        input_positions = common_attn_metadata.positions[:num_actual_tokens].long(
        )

        if self.cos_cache is None:
            self.cos_cache = model.model.layers[
                0].self_attn.rotary_emb.cos_cached
            self.sin_cache = model.model.layers[
                0].self_attn.rotary_emb.sin_cached
        if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
            self.cos_cache = self.cos_cache.to(  # type: ignore
                self.model_config.dtype)  # type: ignore
            self.sin_cache = self.sin_cache.to(  # type: ignore
                self.model_config.dtype)  # type: ignore

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        query_lens = query_seq_lens_cpu[:num_reqs]
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        num_computed_tokens_cpu = (seq_lens - query_lens)

        prefill_metadata = None
        chunked_context_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill_start
            tokens_start = num_decode_tokens
            max_query_len = query_lens[reqs_start:].max().item()
            max_seq_lens = seq_lens[reqs_start:].max().item()
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            if self.chunked_prefill_enabled and max_context_len_cpu > 0:
                # Share the workspace evenly between the prefills that have
                # context, rounded down to a whole number of blocks.
                max_context_chunk = (self.chunked_prefill_workspace_size //
                                     num_prefills_with_context_cpu)
                max_context_chunk = round_down(max_context_chunk,
                                               self.block_size)
                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
                chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
                    .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)
                chunked_context_metadata = \
                    AscendSFAPrefillMetadata.ChunkedContextMetadata(
                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                        starts=chunk_starts.to(device, non_blocking=True),
                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        chunk_seq_lens=chunk_seq_lens,
                        workspace=self.chunked_prefill_workspace,
                    )

            prefill_input_positions = input_positions[tokens_start:]
            cos = self.cos_cache[
                prefill_input_positions].unsqueeze(  # type: ignore
                    1).unsqueeze(2)
            sin = self.sin_cache[
                prefill_input_positions].unsqueeze(  # type: ignore
                    1).unsqueeze(2)
            # query_lens is already a CPU tensor; convert it with .to() rather
            # than copy-constructing through torch.tensor(), which warns and
            # performs an extra copy.
            actual_query_lens = query_lens[reqs_start:].to(torch.int32).npu()
            query_lens_prefill_sfa = torch.cumsum(actual_query_lens,
                                                  dim=0).to(torch.int32)
            seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu()
            prefill_metadata = AscendSFAPrefillMetadata(
                attn_mask=common_attn_metadata.attn_mask,
                query_lens=query_lens_prefill_sfa,
                seq_lens=seq_lens_prefill_sfa,
                context_lens=seq_lens[reqs_start:],
                input_positions=prefill_input_positions,
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
                query_start_loc=prefill_query_start_loc,
                chunked_context=chunked_context_metadata,
                sin=sin,
                cos=cos,
            )

        decode_metadata = None
        if num_decodes > 0:
            # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to(
                torch.int32).npu()
            decode_seq_lens_cpu = seq_lens[:num_decodes]
            max_seq_lens = decode_seq_lens_cpu.max().item()
            # Take the Python list from the CPU tensor *before* moving it to
            # the device, so no device -> host synchronisation is triggered.
            seq_lens_list = decode_seq_lens_cpu.tolist()
            # NOTE(review): seq_lens, input_positions and block_table are
            # rebound to their decode-only slices here, so the top-level
            # fields of the returned metadata carry the decode-only views
            # whenever the batch contains decodes - confirm this is intended.
            seq_lens = decode_seq_lens_cpu.to(torch.int32).npu()
            input_positions = input_positions[:num_decode_tokens]
            block_table = block_table[:num_decodes, ...]
            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
            decode_metadata = AscendSFADecodeMetadata(
                input_positions=input_positions,
                block_table=block_table,
                seq_lens=seq_lens,
                seq_lens_list=seq_lens_list,
                max_seq_lens=max_seq_lens,
                attn_mask=common_attn_metadata.spec_attn_mask,
                actual_seq_lengths_q=actual_seq_lengths_q,
                sin=sin,
                cos=cos)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
            query_lens=query_lens.tolist(),
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            attn_mask=common_attn_metadata.attn_mask,
            attn_state=common_attn_metadata.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
        )
class PrefillSFAPreprocessResult(NamedTuple):
    """Tensors produced by the SFA preprocessing step for prefill tokens.

    Every field defaults to ``None``. Field order is part of the tuple
    layout; do not reorder.
    """
    q_nope: Optional[torch.Tensor] = None
    q_pe: Optional[torch.Tensor] = None
    k_nope: Optional[torch.Tensor] = None
    k_pe: Optional[torch.Tensor] = None
    topk_indices: Optional[torch.Tensor] = None
    # The preprocessing step fills these with the pairs (q_nope, q_pe) and
    # (k_nope, k_pe) respectively.
    query_states: Optional[torch.Tensor] = None
    key_states: Optional[torch.Tensor] = None
class DecodeSFAPreprocessResult(NamedTuple):
    """Tensors produced by the SFA preprocessing step for decode tokens.

    Every field defaults to ``None``. Field order is part of the tuple
    layout; do not reorder.
    """
    q_nope: Optional[torch.Tensor] = None
    q_pe: Optional[torch.Tensor] = None
    topk_indices: Optional[torch.Tensor] = None
    # The preprocessing step fills these with the pairs (q_nope, q_pe) and
    # (k_nope, k_rope) respectively.
    query_states: Optional[torch.Tensor] = None
    key_states: Optional[torch.Tensor] = None
    # Leading dimension of the projected decode query.
    bsz: Optional[int] = None
class AscendSFAImpl(MLAAttentionImpl):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **kwargs,
) -> None:
    """Store attention hyper-parameters, projection modules and indexer state.

    Required ``kwargs`` keys: ``q_lora_rank``, ``kv_lora_rank``,
    ``qk_nope_head_dim``, ``qk_rope_head_dim``, ``qk_head_dim``,
    ``v_head_dim``, ``rotary_emb``, ``q_proj``, ``kv_b_proj``, ``o_proj``
    and ``indexer``. Optional keys (``None`` when absent):
    ``kv_a_proj_with_mqa``, ``kv_a_layernorm``, ``q_a_proj``,
    ``q_a_layernorm``.
    """
    # --- generic attention arguments ---
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    # --- MLA arguments (mandatory) ---
    self.q_lora_rank = kwargs['q_lora_rank']
    self.kv_lora_rank = kwargs['kv_lora_rank']
    self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
    self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
    self.qk_head_dim = kwargs['qk_head_dim']
    self.v_head_dim = kwargs['v_head_dim']
    self.rotary_emb = kwargs['rotary_emb']
    self.q_proj = kwargs['q_proj']
    self.kv_b_proj = kwargs['kv_b_proj']
    self.o_proj = kwargs['o_proj']
    indexer = kwargs['indexer']
    self.indexer = indexer

    # --- MLA arguments (optional) ---
    self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa')
    self.kv_a_layernorm = kwargs.get('kv_a_layernorm')
    self.q_a_proj = kwargs.get('q_a_proj')
    self.q_a_layernorm = kwargs.get('q_a_layernorm')

    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.tp_size = get_tensor_model_parallel_world_size()
    self.num_heads_per_rank = self.num_heads // self.tp_size

    # With a down-projection present, q_proj acts as the up-projection.
    self.q_b_proj = self.q_proj if self.q_a_proj is not None else None

    ascend_config = get_ascend_config()
    self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
    self.enable_prefetch = ascend_config.enable_prefetch
    self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz

    vllm_config = get_current_vllm_config()
    self.ring_mla_mask_size = 512
    self.prefill_mask = None

    # --- attributes mirrored from the indexer module ---
    self.dim = indexer.dim
    self.n_heads: int = indexer.n_heads  # 64
    self.head_dim: int = indexer.head_dim  # 128
    self.index_topk: int = indexer.index_topk  # 2048
    self.wq_b = indexer.wq_b
    self.wk = indexer.wk
    self.weights_proj = indexer.weights_proj
    self.k_norm = indexer.k_norm
    self.softmax_scale = indexer.softmax_scale

    # Adapt torch air graph mode with spec decoding.
    spec_cfg = vllm_config.speculative_config
    if spec_cfg is not None:
        self.spec_token_num = spec_cfg.num_speculative_tokens
        assert self.spec_token_num > 0

    self.cp_size = 1
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Split ``kv_b_proj`` into per-head K and V up-projection matrices.

    Stores ``self.kv_b_proj_w_k`` with layout (N, P, L) and
    ``self.kv_b_proj_w_v`` with layout (N, L, V), where L is
    ``kv_lora_rank``, N is ``num_heads``, P is ``qk_nope_head_dim`` and V is
    ``v_head_dim``. Quantized weights are dequantized to ``act_dtype`` first.
    """

    def get_layer_weight(layer):
        # Return the first weight-like attribute present on the layer.
        WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
        present = next(
            (name for name in WEIGHT_NAMES if hasattr(layer, name)), None)
        if present is None:
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")
        return getattr(layer, present)

    def get_and_maybe_dequant_weights(layer: LinearBase):
        if isinstance(layer.quant_method, UnquantizedLinearMethod):
            return layer.weight
        # Dequantize by pushing an identity matrix through the quantized
        # layer. NOTE: This should only be used offline, since it's O(N^3)
        identity = torch.eye(layer.input_size_per_partition,
                             dtype=act_dtype,
                             device=get_layer_weight(layer).device)
        dequant_weights = layer.quant_method.apply(layer, identity, bias=None)
        del identity
        # standardize to (output, input)
        return dequant_weights.T

    # There are currently no quantized bmm's, which the absorbed K/V
    # up-projections need, so fp16/bf16 copies are stored and the bmm's run
    # in 16-bit; the extra memory overhead of this is fairly low.
    kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
    merged_head_dim = self.qk_nope_head_dim + self.v_head_dim
    assert kv_b_proj_weight.shape == (
        self.kv_lora_rank, self.num_heads * merged_head_dim), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}")
    kv_b_proj_weight = kv_b_proj_weight.view(self.kv_lora_rank,
                                             self.num_heads, merged_head_dim)
    w_k, w_v = kv_b_proj_weight.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

    # Convert from (L, N, V) to (N, L, V)
    self.kv_b_proj_w_v = w_v.transpose(0, 1).contiguous()
    # Convert from (L, N, P) to (N, P, L)
    self.kv_b_proj_w_k = w_k.permute(1, 2, 0).contiguous()
    # A cast to the NZ format is deferred until BMM supports it.
def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
                    need_gather_q_kv):
    """Run the projections and rope/cache kernels that precede SFA attention.

    Flow, as implemented below:
      1. If ``need_gather_q_kv``, all-gather ``hidden_states`` along dim 0
         across the tensor-parallel group.
      2. For the decode tokens (the first ``num_decode_tokens`` rows) and
         for the prefill tokens (rows ``num_decode_tokens`` up to
         ``num_actual_tokens``), independently:
           - q_a_proj -> q_a_layernorm -> q_b_proj, split into nope/rope
             parts, and multiply the nope part with ``kv_b_proj_w_k``;
           - kv_a_proj_with_mqa, then ``npu_kv_rmsnorm_rope_cache`` with the
             phase's slot mapping and ``kv_cache[1]`` / ``kv_cache[0]``
             (presumably this is where the KV cache is written - confirm
             against the torch_npu documentation);
           - ``npu_interleave_rope`` on the rope part of the query;
           - ``indexer_select`` to obtain ``topk_indices``.

    Returns:
        ``(decode_preprocess_res, prefill_preprocess_res)``; each entry is
        ``None`` when the batch has no requests of that phase.
    """
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens
    num_actual_tokens = attn_metadata.num_actual_tokens
    if need_gather_q_kv:
        # The raw hidden states are gathered (not q_c / kv_no_split); the
        # projections below then run on the gathered tensor.
        hidden_states = get_tp_group().all_gather(hidden_states, 0)

    decode_preprocess_res = None
    prefill_preprocess_res = None
    # Preprocess for decode tokens
    if has_decode:
        # Every decode row is viewed as a query of length one.
        q_len = 1
        hidden_states_decode = hidden_states[:num_decode_tokens]
        decode_kq = self.q_a_proj(hidden_states_decode)  # q down
        decode_q_c = self.q_a_layernorm(decode_kq)  # q down layernorm
        decode_kv_no_split = self.kv_a_proj_with_mqa(
            hidden_states_decode)  # c_kv
        decode_slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens]
        decode_q = self.q_b_proj(decode_q_c)
        bsz, _ = decode_q.shape
        decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim)
        decode_q_nope, decode_q_pe = torch.split(
            decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
            dim=-1)
        # Heads-first layout for the batched matmul with kv_b_proj_w_k, then
        # back to token-first with kv_lora_rank as the last dimension.
        decode_q_nope = decode_q_nope.view(
            -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
        decode_q_nope = (torch.matmul(decode_q_nope,
                                      self.kv_b_proj_w_k).transpose(
                                          1,
                                          0).view(bsz, q_len,
                                                  self.num_heads,
                                                  self.kv_lora_rank))

        # stream2 kv
        key_cache = kv_cache[0]
        value_cache = kv_cache[1]
        cos = attn_metadata.decode.cos
        sin = attn_metadata.decode.sin
        # NOTE(review): cos_q / sin_q are not used in the decode branch.
        cos_q, sin_q = cos, sin
        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
        decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1)
        decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
            decode_kv_no_split,
            self.kv_a_layernorm.weight,
            cos,
            sin,
            decode_slot_mapping.to(torch.int64),
            value_cache,
            key_cache,
            c_kv_scale=None,
            epsilon=self.kv_a_layernorm.variance_epsilon,
            cache_mode='PA')  # the NZ cache layout is not adapted here
        decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos,
                                                    sin)  # BNSD
        decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
                                           self.kv_lora_rank)
        decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
        topk_indices = self.indexer_select(hidden_states_decode,
                                           decode_q_c,
                                           attn_metadata=attn_metadata,
                                           kv_cache=kv_cache)
        query_states = (decode_q_nope, decode_q_pe)
        key_states = (decode_k_nope, decode_k_rope)
        decode_preprocess_res = DecodeSFAPreprocessResult(
            q_nope=decode_q_nope,
            q_pe=decode_q_pe,
            topk_indices=topk_indices,
            query_states=query_states,
            key_states=key_states,
            bsz=bsz,
        )
    # Preprocess for prefill tokens
    if has_prefill:
        # NOTE(review): bsz is not used in the prefill branch.
        bsz = 1
        hidden_states_prefill = hidden_states[
            num_decode_tokens:num_actual_tokens]
        prefill_kq = self.q_a_proj(hidden_states_prefill)  # q down
        prefill_q_c = self.q_a_layernorm(prefill_kq)  # q down layernorm
        prefill_kv_no_split = self.kv_a_proj_with_mqa(
            hidden_states_prefill)  # c_kv
        prefill_slot_mapping = attn_metadata.slot_mapping[
            num_decode_tokens:num_actual_tokens]
        # NOTE(review): exact duplicate of the assignment just above.
        prefill_slot_mapping = attn_metadata.slot_mapping[
            num_decode_tokens:num_actual_tokens]
        prefill_qr = prefill_q_c
        prefill_q = self.q_b_proj(prefill_qr)
        prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim)
        prefill_q_nope, prefill_q_pe = torch.split(
            prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
            dim=-1)
        # Same heads-first matmul with kv_b_proj_w_k as in the decode branch.
        prefill_q_nope = prefill_q_nope.view(
            -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
        prefill_q_nope = (torch.matmul(prefill_q_nope,
                                       self.kv_b_proj_w_k).transpose(
                                           1,
                                           0).view(-1, self.num_heads,
                                                   self.kv_lora_rank))
        prefill_q_pe = prefill_q_pe.unsqueeze(2)

        # stream2 kv
        nope_cache = kv_cache[0]
        rope_cache = kv_cache[1]
        cos = attn_metadata.prefill.cos
        sin = attn_metadata.prefill.sin
        # Unlike the decode branch, the query rope uses cos/sin as stored in
        # the metadata; only the kv kernel below gets the 4-D views.
        cos_q, sin_q = cos, sin
        prefill_q_pe = torch_npu.npu_interleave_rope(
            prefill_q_pe, cos_q, sin_q)  # BNSD
        prefill_q_pe = prefill_q_pe.squeeze(2)  # BSH
        prefill_latent_cache = prefill_kv_no_split  # (B,S,N,D)
        prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
            prefill_latent_cache.view(
                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
            self.kv_a_layernorm.weight,
            cos.view(-1, 1, 1, self.qk_rope_head_dim),
            sin.view(-1, 1, 1, self.qk_rope_head_dim),
            prefill_slot_mapping.to(torch.int64),
            rope_cache,
            nope_cache,
            k_rope_scale=None,
            c_kv_scale=None,
            k_rope_offset=None,
            c_kv_offset=None,
            epsilon=self.kv_a_layernorm.variance_epsilon,
            cache_mode="PA")
        topk_indices = self.indexer_select(x=hidden_states_prefill,
                                           qr=prefill_qr,
                                           kv_cache=kv_cache,
                                           attn_metadata=attn_metadata)
        query_states = (prefill_q_nope, prefill_q_pe)
        key_states = (prefill_k_nope, prefill_k_pe)
        prefill_preprocess_res = PrefillSFAPreprocessResult(
            q_nope=prefill_q_nope,
            q_pe=prefill_q_pe,
            topk_indices=topk_indices,
            k_nope=prefill_k_nope,
            k_pe=prefill_k_pe,
            query_states=query_states,
            key_states=key_states,
        )
    return decode_preprocess_res, prefill_preprocess_res
def forward(
    self,
    hidden_states: torch.Tensor,  # query in unified attn
    kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    attn_metadata: M,
    need_gather_q_kv: bool = False,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run SFA attention for one layer and write the result into ``output``.

    Returns ``output`` untouched when ``attn_metadata`` is ``None``
    (profiling run); otherwise returns the first ``num_actual_tokens`` rows
    of ``output`` after filling them with the projected attention result.

    Raises:
        AssertionError: if ``output`` is ``None`` or the decode/prefill
            counters on ``attn_metadata`` are missing.
    """
    assert output is not None, "Output tensor must be provided."
    # Profiling run.
    if attn_metadata is None:
        return output

    token_count = attn_metadata.num_actual_tokens
    assert attn_metadata.num_decodes is not None
    assert attn_metadata.num_prefills is not None
    assert attn_metadata.num_decode_tokens is not None
    decode_token_count = attn_metadata.num_decode_tokens

    # Inputs and outputs may be padded for graph capture.
    output = output[:token_count, ...]
    o_proj_input = torch.empty(
        (token_count, self.num_heads * self.v_head_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device)

    # SFA Preprocess
    decode_res, prefill_res = self._sfa_preprocess(hidden_states, kv_cache,
                                                   attn_metadata,
                                                   need_gather_q_kv)

    # Decode rows occupy the front of the buffer, prefill rows the rest.
    phases = (
        (decode_res, slice(None, decode_token_count)),
        (prefill_res, slice(decode_token_count, None)),
    )
    for phase_res, rows in phases:
        if phase_res is None:
            continue
        o_proj_input[rows] = self.apply_attention_fusion(
            query_states=phase_res.query_states,
            key_states=phase_res.key_states,
            attn_metadata=attn_metadata,
            topk_indices=phase_res.topk_indices)

    output[...] = self.mla_epilog(o_proj_input, absorb=True)
    return output
def apply_attention_fusion(self, query_states, key_states, topk_indices,
                           attn_metadata: M):
    """Run the sparse flash-attention kernel and apply the V up-projection.

    Args:
        query_states: Pair ``(q_nope, q_pe)``.
        key_states: Pair ``(k_nope, k_rope)``; ``k_nope`` is passed to the
            kernel as both key and value.
        topk_indices: Sparse indices forwarded to the kernel.
        attn_metadata: Batch metadata; block table and sequence lengths are
            taken from its ``prefill`` sub-metadata when present, otherwise
            from ``decode``.

    Returns:
        Tensor reshaped to ``(-1, num_heads * v_head_dim)``.

    Raises:
        ValueError: if ``attn_metadata`` carries neither prefill nor decode
            metadata.
    """
    q_nope, q_pe = query_states
    k_nope, k_rope = key_states

    # NOTE(review): prefill metadata takes precedence. ``forward`` calls this
    # method once per phase with the same ``attn_metadata``, so in a batch
    # that has both phases the decode call also lands in the prefill branch -
    # confirm whether mixed batches are expected to reach this backend.
    if attn_metadata.prefill is not None:
        prefill_metadata = attn_metadata.prefill
        block_table = prefill_metadata.block_table
        actual_seq_lengths_query = prefill_metadata.query_lens
        actual_seq_lengths_kv = prefill_metadata.seq_lens
    elif attn_metadata.decode is not None:
        decode_metadata = attn_metadata.decode
        block_table = decode_metadata.block_table
        actual_seq_lengths_query = decode_metadata.actual_seq_lengths_q
        actual_seq_lengths_kv = decode_metadata.seq_lens
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            "attn_metadata must carry prefill or decode metadata to run "
            "sparse flash attention.")

    slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
        query=q_nope,
        key=k_nope,
        value=k_nope,
        sparse_indices=topk_indices,
        scale_value=self.scale,
        sparse_block_size=1,
        block_table=block_table,
        actual_seq_lengths_query=actual_seq_lengths_query,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        query_rope=q_pe,
        key_rope=k_rope,
        layout_query="TND",
        layout_kv="PA_BSND",
        sparse_mode=3,
    )
    slc_fa_fusion = slc_fa_fusion.squeeze(1)
    slc_fa_fusion = slc_fa_fusion.transpose(0, 1)

    # input shape [N//attn_tp_size, T(bs*q_len), D]
    # output shape [T(bs*q_len), N//attn_tp_size, D]
    attn_output = torch.matmul(slc_fa_fusion,
                               self.kv_b_proj_w_v).transpose(1, 0).reshape(
                                   -1, self.num_heads * self.v_head_dim)
    return attn_output
def mla_epilog(self,
               attn_output: torch.Tensor = None,
               absorb: bool = False):
    """Flatten the attention output per token and apply ``self.o_proj``.

    Args:
        attn_output: attention result; everything after the first
            dimension is collapsed into one before the projection.
        absorb: accepted for interface compatibility; not read here.

    Returns:
        Whatever ``self.o_proj`` returns for the flattened input.
    """
    # TODO: need to check
    num_tokens = attn_output.shape[0]
    flattened = attn_output.reshape(num_tokens, -1)
    projected = self.o_proj(flattened,
                            is_prefill=True,
                            is_force_scatter=False)
    return projected
def indexer_select(
self,
x: torch.Tensor,
qr: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
):
"""Compute indexer queries/keys, store the keys and return top-k indices.

``qr`` is projected by ``self.wq_b`` into per-head queries and ``x`` is
projected by ``self.wk`` / ``self.k_norm`` into a key. Both get interleaved
RoPE applied to their first ``qk_rope_head_dim`` channels. The key is
scattered into ``kv_cache[2]`` at ``attn_metadata.slot_mapping`` and
``npu_lightning_indexer`` is then run against that cache.

Returns the tensor produced by ``torch.ops.custom.npu_lightning_indexer``.
"""
# Prefill metadata takes precedence over decode metadata.
# NOTE(review): if both are None, `cos` below is unbound and this raises
# UnboundLocalError - confirm callers never reach here in that state.
if attn_metadata.prefill is not None:
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
actual_seq_lengths_query = attn_metadata.prefill.query_lens
actual_seq_lengths_key = attn_metadata.prefill.seq_lens
block_table = attn_metadata.prefill.block_table
elif attn_metadata.decode is not None:
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
actual_seq_lengths_key = attn_metadata.decode.seq_lens
block_table = attn_metadata.decode.block_table
# The query path uses cos/sin as provided; the key path uses the
# (-1, 1, 1, qk_rope_head_dim) views created below.
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
# q process in new stream
q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128]
# The rotary part comes first in the last dimension, the rest follows.
q_pe, q_nope = torch.split(
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64,64+64]
# A singleton dimension is inserted only for the rope call and removed
# again right after it.
q_pe = q_pe.unsqueeze(2)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k = self.k_norm(k_proj).unsqueeze(1)
k_pe, k_nope = torch.split(
k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]
k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
# In-place write of the new keys into the third cache tensor, addressed
# by slot_mapping.
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(
-1, 1),
k.view(-1,
k.shape[-1])) # b, s, n, d
weights = self.weights_proj(x)
# NOTE(review): kv_cache[2] is also read here; the nesting of this call
# relative to the `kv_cache is not None` check is not visible in this view -
# confirm kv_cache can never be None when the indexer runs.
# sparse_count=2048 is hard-coded at this call site.
topk_indices = torch.ops.custom.npu_lightning_indexer(
query=q,
key=kv_cache[2],
weights=weights,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=block_table,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3)
return topk_indices

View File

@@ -1,7 +1,11 @@
from dataclasses import dataclass
from typing import Any
from typing import Any, List
import torch
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
@dataclass
@@ -21,6 +25,13 @@ class AscendCommonAttentionMetadata:
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
seq_lens: torch.Tensor
"""same to seq_lens_cpu, for compatibility with some new attn metadata
(such as GDN)."""
num_computed_tokens_cpu: torch.Tensor
"""(batch_size,), the number of computed tokens for each request"""
num_reqs: int
"""Number of requests"""
num_actual_tokens: int
@@ -34,7 +45,7 @@ class AscendCommonAttentionMetadata:
block_table_tensor: torch.Tensor
slot_mapping_cpu: torch.Tensor
slot_mapping: torch.Tensor
actual_seq_lengths_q: list[int]
@@ -93,3 +104,34 @@ def split_decodes_and_prefills(
num_decode_tokens = query_start_loc[first_prefill].item()
num_prefill_tokens = num_tokens - num_decode_tokens
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
def wait_for_kv_layer_from_connector(layer_name: str):
    """Block until the KV connector has finished loading ``layer_name``.

    Does nothing when no v1 KV-transfer group is configured or when the
    current forward context carries no attention metadata.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    context: ForwardContext = get_forward_context()
    # TODO: assert ascendMetadata
    if context.attn_metadata is not None:
        kv_connector.wait_for_layer_load(layer_name)
def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand one layer's KV cache to the KV connector for saving.

    Does nothing when no v1 KV-transfer group is configured or when the
    current forward context carries no attention metadata.

    Args:
        layer_name: name of the attention layer being saved.
        kv_cache_layer: the cache tensors of that layer.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    context: ForwardContext = get_forward_context()
    metadata = context.attn_metadata
    # TODO: assert ascendMetadata
    if metadata is not None:
        kv_connector.save_kv_layer(layer_name, kv_cache_layer, metadata)

View File

@@ -3,10 +3,12 @@
import dataclasses
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import torch_npu
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphOptions
@@ -15,7 +17,8 @@ from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm.utils import weak_ref_tensors
from ..utils import weak_ref_tensors
@dataclasses.dataclass
@@ -35,10 +38,10 @@ class ACLGraphWrapper:
The workflow of this wrapper in the aclgraph dispatching is as follows:
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
PIECEWISE).
2. At runtime, the wrapper receives a runtime_mode and a
PIECEWISE).
2. At runtime, the wrapper receives a runtime_mode and a
batch_descriptor(key) from the forward context and blindly trust them
for aclgraph dispatching.
for aclgraph dispatching.
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
wrapper, just call the runnable directly.
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
@@ -47,9 +50,9 @@ class ACLGraphWrapper:
Note: ACLGraphWrapper does not store persistent buffers or copy any
runtime inputs into that buffers for replay. We assume implementing them
is done outside of the wrapper. That is because we do not make any
is done outside of the wrapper. That is because we do not make any
assumption on the dynamic shape (batch size) of the runtime inputs, as a
trade-off for staying orthogonal to compilation logic. Nevertheless,
trade-off for staying orthogonal to compilation logic. Nevertheless,
tracing and checking the input addresses to be consistent during replay is
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""
@@ -146,6 +149,7 @@ class ACLGraphWrapper:
patch("torch.npu.empty_cache", lambda: None))
# mind-exploding: carefully manage the reference and memory.
forward_context.capturing = True
with torch.npu.graph(aclgraph, pool=self.graph_pool):
# `output` is managed by pytorch's aclgraph pool
output = self.runnable(*args, **kwargs)
@@ -183,3 +187,74 @@ class ACLGraphWrapper:
logger.info_once("Replaying aclgraph")
entry.aclgraph.replay()
return entry.output
def update_attn_params(update_stream, forward_context, runtime_shape):
"""Re-issue each captured paged-attention call with the current seq_lens.

For every entry recorded under ``runtime_shape`` in the global graph
parameters, the paged-attention op is replayed between
``graph_task_update_begin`` / ``graph_task_update_end`` on
``update_stream`` and the matching event is recorded on that stream.
"""
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
# `key` iterates forward_context.attn_metadata directly (presumably a dict
# keyed by layer name - TODO confirm); zip stops at the shortest input.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
# block_table = forward_context.attn_metadata[key].block_tables
# Only seq_lens is refreshed from the current step; the seq_lens stored
# in `param` is discarded and block_table keeps its stored value.
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@dataclass
class GraphParams:
# Per-capture-size bookkeeping read by update_attn_params. Every dict is
# keyed by a capture size (the `runtime_shape` used for lookups).
# Events recorded on the update stream after each attention update.
events: dict[int, list[torch.npu.ExternalEvent]]
# One workspace slot per capture size; set_graph_params initialises it to None.
workspaces: dict[int, torch.Tensor]
# Task-group handles passed to torch.npu.graph_task_update_begin.
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
# Argument tuples of the paged-attention calls, unpacked in update_attn_params.
attn_params: dict[int, list[tuple]]
# Module-level singleton; assigned once by set_graph_params.
_graph_params: Optional[GraphParams] = None
def set_graph_params(aclgraph_capture_sizes: set[int]):
    """Create the module-level ``GraphParams`` with one slot per capture size.

    Args:
        aclgraph_capture_sizes: capture sizes to create entries for.

    Raises:
        ValueError: if the graph parameters were already initialised.
    """
    global _graph_params
    if _graph_params is not None:
        raise ValueError("Graph parameters have already been set!")

    # Each list-valued field gets its own fresh list per size; the workspace
    # slot starts out as None.
    _graph_params = GraphParams(
        events={size: [] for size in aclgraph_capture_sizes},
        workspaces=dict.fromkeys(aclgraph_capture_sizes),
        handles={size: [] for size in aclgraph_capture_sizes},
        attn_params={size: [] for size in aclgraph_capture_sizes},
    )
def get_graph_params():
"""Return the module-level graph parameters (None until set_graph_params runs)."""
return _graph_params

View File

@@ -20,14 +20,19 @@ from typing import Type, Union
from vllm.config import SchedulerConfig
MAX_INT = 2147483647
@dataclass
class AscendSchedulerConfig(SchedulerConfig):
enable_chunked_prefill: bool = False
max_long_partial_prefills: int = MAX_INT
long_prefill_token_threshold: int = MAX_INT
policy: str = "fcfs"
num_scheduler_steps: int = 1
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.scheduler.AscendScheduler")
enable_pd_transfer: bool = False
decode_max_num_seqs: int = 0
@classmethod
def initialize_from_config(
@@ -41,10 +46,13 @@ class AscendSchedulerConfig(SchedulerConfig):
}
# Override default values into original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["max_long_partial_prefills"] = None
scheduler_config["long_prefill_token_threshold"] = None
scheduler_config["policy"] = "fcfs"
scheduler_config["num_scheduler_steps"] = 1
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
scheduler_config["enable_pd_transfer"] = False
scheduler_config["decode_max_num_seqs"] = 0
# Override params in original SchedulerConfig with params in ascend_scheduler_config
for k, _ in scheduler_config.items():
if hasattr(ascend_scheduler_config, k):
@@ -65,20 +73,36 @@ class AscendSchedulerConfig(SchedulerConfig):
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")
# concurrent partial prefills. Default is inf
if self.max_long_partial_prefills is None:
self.max_long_partial_prefills = MAX_INT
self.long_prefill_token_threshold = MAX_INT
if self.long_prefill_token_threshold is None or \
self.long_prefill_token_threshold <= 0:
if self.max_model_len is None:
self.long_prefill_token_threshold = MAX_INT
else:
self.long_prefill_token_threshold = \
max(1, int(self.max_model_len * 0.04))
if self.max_long_partial_prefills < 0:
raise ValueError(
f"max_long_partial_prefills must be non-negative, but got "
f"{self.max_long_partial_prefills}")
if self.long_prefill_token_threshold < 0:
raise ValueError(
f"long_prefill_token_threshold must be non-negative, but got "
f"{self.long_prefill_token_threshold}")
if self.policy != "fcfs":
raise NotImplementedError(
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
)
if self.is_multimodal_model:
raise NotImplementedError(
"currently AscendScheduler only supports LLM models.")
if self.num_scheduler_steps > 1:
raise NotImplementedError(
"currently AscendScheduler doesn't support multi-step.")
if self.send_delta_data:
raise NotImplementedError(
"currently AscendScheduler doesn't support send_delta_data.")
if self.delay_factor > 0:
if getattr(self, "scheduler_delay_factor", 0) > 0:
raise NotImplementedError(
"currently AscendScheduler doesn't support scheduler_delay_factor."
)

View File

@@ -23,6 +23,7 @@ from vllm.distributed.kv_events import KVEventBatch
from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
@@ -31,13 +32,6 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
else:
KVCacheBlocks = None
class AscendScheduler(Scheduler):
"""This Scheduler extends vllm's original v1 scheduler
@@ -58,6 +52,15 @@ class AscendScheduler(Scheduler):
self.scheduled_req_ids: set[str] = set()
self.running: list[Request] = []
self.finished_prefill_reqs: deque[Request] = deque()
enable_pd_transfer = getattr(self.scheduler_config,
'enable_pd_transfer', False)
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
self.phase = "" if not enable_pd_transfer else "prefill"
self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
decode_max_num_seqs)
def schedule(self) -> SchedulerOutput:
if self.scheduler_config.chunked_prefill_enabled:
return super().schedule()
@@ -66,12 +69,14 @@ class AscendScheduler(Scheduler):
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
req_to_new_block_ids: dict[str, list[list[int]]] = {}
else:
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
@@ -85,9 +90,33 @@ class AscendScheduler(Scheduler):
# and put back at the head of the waiting queue later
skipped_waiting_requests: deque[Request] = deque()
if self.phase == "prefill":
remaining_running_reqs = []
for request in self.running:
# Move requests that have finished prefill to finished_prefill_reqs.
if request.num_tokens > request.num_prompt_tokens:
self.finished_prefill_reqs.append(request)
else:
remaining_running_reqs.append(request)
self.running = remaining_running_reqs
# all request prefilled, change phase to decode
if not self.waiting and not self.running:
self.phase = "decode"
# Skip long prompt requests in prefill stage.
# long_prefill_budget is float('inf') when the long-prefill limit is not in use.
if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0:
long_prefill_budget = float('inf')
long_prefill_token_threshold = float('inf')
else:
long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold
# Schedule prefill requests first.
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
if len(self.running) == (self.decode_max_num_running_reqs
if self.phase == "decode" else
self.max_num_running_reqs):
break
request = self.waiting[0]
@@ -139,6 +168,9 @@ class AscendScheduler(Scheduler):
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
@@ -176,6 +208,17 @@ class AscendScheduler(Scheduler):
assert num_new_tokens > 0
blocks = new_computed_blocks.blocks[0]
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0 or len(
encoder_inputs_to_schedule) == 0:
# The request cannot be scheduled.
break
watermark = getattr(self.scheduler_config, "watermark", 0.01)
if not self._check_watermark_for_prefill(request, num_new_tokens,
blocks, watermark):
@@ -183,6 +226,11 @@ class AscendScheduler(Scheduler):
skip_cur_request()
continue
if num_new_tokens > long_prefill_token_threshold \
and long_prefill_budget <= 0:
skip_cur_request()
continue
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
@@ -227,26 +275,41 @@ class AscendScheduler(Scheduler):
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_blocks[
request.request_id] = self.kv_cache_manager.get_blocks(
request.request_id)
req_to_new_blocks[
request.request_id] = self.kv_cache_manager.get_blocks(
request.request_id)
# Update request info.
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
if num_new_tokens > long_prefill_token_threshold:
long_prefill_budget -= 1
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests)
if self.phase == "decode":
while len(
self.running
) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
request = self.finished_prefill_reqs.popleft()
self.running.append(request)
# If no prefill requests are scheduled,
# Schedule decode requests next.
if len(self.scheduled_req_ids) == 0:
@@ -267,6 +330,16 @@ class AscendScheduler(Scheduler):
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_budget)
# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request and (
@@ -322,11 +395,7 @@ class AscendScheduler(Scheduler):
# Schedule the request.
scheduled_running_reqs.append(request)
self.scheduled_req_ids.add(request.request_id)
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
else:
req_to_new_blocks[request.request_id] = new_blocks
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
@@ -342,6 +411,15 @@ class AscendScheduler(Scheduler):
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Record scheduled LoRA requests.
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
@@ -350,7 +428,9 @@ class AscendScheduler(Scheduler):
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# The conditional expression must be parenthesized. Without the parentheses
# Python parses the statement as
#   assert (len(self.running) <= decode_max) if phase == "decode" else max_num
# so outside the decode phase only the truthiness of max_num_running_reqs
# was checked and the running-queue bound was never enforced.
assert len(self.running) <= (
    self.decode_max_num_running_reqs
    if self.phase == "decode" else self.max_num_running_reqs)
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
scheduled_running_reqs) <= len(self.running)
@@ -365,67 +445,36 @@ class AscendScheduler(Scheduler):
any_request, len(self.running)))
# Construct the scheduler output.
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs, scheduled_resumed_reqs,
num_scheduled_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids)
else:
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs, scheduled_resumed_reqs,
num_scheduled_tokens, scheduled_spec_decode_tokens,
req_to_new_blocks)
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs, scheduled_resumed_reqs,
num_scheduled_tokens, scheduled_spec_decode_tokens,
req_to_new_blocks)
scheduled_cached_reqs = cached_reqs_data
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=scheduled_cached_reqs,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs={},
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids, # type: ignore
free_encoder_input_ids=self.encoder_cache_manager.
get_freed_ids(),
structured_output_request_ids={},
grammar_bitmask=None,
)
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=scheduled_cached_reqs,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs={},
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids, # type: ignore
free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
structured_output_request_ids={},
grammar_bitmask=None,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=scheduled_cached_reqs,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids, # type: ignore
free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
structured_output_request_ids={},
grammar_bitmask=None,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
# 1. Plan the KV cache store

View File

@@ -26,3 +26,8 @@ KVConnectorFactory.register_connector(
KVConnectorFactory.register_connector(
"MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector",
"MooncakeConnector")
KVConnectorFactory.register_connector(
"MooncakeConnectorStoreV1",
"vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
"MooncakeConnectorV1")

View File

@@ -0,0 +1,457 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import queue
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Sequence
import torch
from vllm.attention import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.utils import logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm_ascend.distributed.cpu_offload_manager.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@dataclass
class ReqMeta:
    """Per-request block ids and token counters carried in connector metadata."""

    gpu_block_ids: list[int]
    cpu_block_ids: list[int]
    num_scheduled_tokens: int
    num_computed_tokens: int
    num_gpu_computed_tokens: int
    num_cpu_computed_tokens: int

    def update(self, other: "ReqMeta"):
        """Fold ``other`` into this entry.

        Block-id lists are extended in place with the ids from ``other``;
        all token counters are overwritten with the values from ``other``.
        """
        for list_field in ("gpu_block_ids", "cpu_block_ids"):
            getattr(self, list_field).extend(getattr(other, list_field))
        for counter_field in (
                "num_scheduled_tokens",
                "num_computed_tokens",
                "num_gpu_computed_tokens",
                "num_cpu_computed_tokens",
        ):
            setattr(self, counter_field, getattr(other, counter_field))
@dataclass
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
# Metadata object returned by CPUOffloadingConnectorScheduler.build_connector_meta.
# Per-request offloading state, keyed by request id.
requests: dict[str, ReqMeta]
# Ids the scheduler collected through request_finished since the last build.
finished_req_ids: set[str]
class CPUOffloadingConnector(KVConnectorBase_V1):
    """KV connector facade that forwards to a scheduler- or worker-side delegate.

    Exactly one delegate is created, depending on ``role``. When prefix
    caching is disabled no delegate is created at all and every method
    degrades to a no-op / neutral return value.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        self.connector_scheduler: Optional[
            CPUOffloadingConnectorScheduler] = None
        self.connector_worker: Optional[CPUOffloadingConnectorWorker] = None
        if not vllm_config.cache_config.enable_prefix_caching:
            # Connector disabled: keep both delegates unset.
            return
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = CPUOffloadingConnectorScheduler(
                vllm_config)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(
            self, connector_metadata: KVConnectorMetadata) -> None:
        """Pass this step's metadata to the worker delegate, if any."""
        if self.connector_worker is not None:
            assert isinstance(connector_metadata,
                              CPUOffloadingConnectorMetadata)
            self.connector_worker.bind_connector_metadata(connector_metadata)

    def clear_connector_metadata(self) -> None:
        """Drop the bound metadata; a no-op when there is no worker delegate."""
        # Guarded like the other worker-side methods so that a disabled
        # connector does not fail an assertion on every step.
        if self.connector_worker is not None:
            self.connector_worker.clear_connector_metadata()

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Forward the KV cache tensors to the worker delegate, if any."""
        if self.connector_worker is not None:
            self.connector_worker.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Ask the worker delegate to start loading; arguments are not forwarded."""
        if self.connector_worker is not None:
            self.connector_worker.start_load_kv()

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Wait on the worker delegate; ``layer_name`` is not forwarded."""
        if self.connector_worker is not None:
            self.connector_worker.wait_for_layer_load()

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """Intentionally a no-op in this connector."""
        pass

    def wait_for_save(self):
        """Intentionally a no-op in this connector."""
        pass

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """Return ``(finished_sending, finished_recving)``.

        The second element is always ``None``. Without a worker delegate
        nothing can have finished, so ``(None, None)`` is returned.
        """
        if self.connector_worker is None:
            return None, None
        return self.connector_worker.get_finished(), None

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Return ``(extra matched tokens, load asynchronously)``; ``(0, False)`` when disabled."""
        if self.connector_scheduler is not None:
            return self.connector_scheduler.get_num_new_matched_tokens(
                request, num_computed_tokens)
        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Tell the scheduler delegate that blocks were allocated for ``request``."""
        if self.connector_scheduler is not None:
            return self.connector_scheduler.update_state_after_alloc(request)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """Build this step's metadata; an empty ``KVConnectorMetadata`` when disabled."""
        if self.connector_scheduler is not None:
            return self.connector_scheduler.build_connector_meta(
                scheduler_output)
        return KVConnectorMetadata()

    def request_finished(
            self, request: "Request",
            block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
        """Notify the scheduler delegate that ``request`` finished.

        Returns:
            ``(True, None)`` when a scheduler delegate took over the request,
            i.e. the caller should hold the blocks until the request id is
            reported by ``get_finished``. ``(False, None)`` when there is no
            delegate, because nothing would ever report the request back.
        """
        if self.connector_scheduler is None:
            return False, None
        self.connector_scheduler.request_finished(request)
        return True, None
class CPUOffloadingConnectorScheduler:
# Scheduler-side half of the CPU offloading connector. It keeps per-step
# bookkeeping and talks to the metadata server through a ZMQ RPC client.
def __init__(self, vllm_config: VllmConfig):
logger.info("init CPUOffloadingConnectorScheduler")
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.use_mla = vllm_config.model_config.use_mla
# Per-step state; all four containers are cleared at the end of
# build_connector_meta.
self.num_gpu_computed_tokens: dict[str, int] = {}
self.num_cpu_computed_tokens: dict[str, int] = {}
self.allocated_req_ids: set[str] = set()
self.finished_req_ids: list[str] = []
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
self.zmq_rpc_client.call("post_init")
# Minimum number of additional CPU-cached tokens for which
# get_num_new_matched_tokens reports a match; read from the
# kv_transfer extra config, 0 when not configured.
if vllm_config.kv_transfer_config is not None:
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
"swap_in_threshold", 0)
else:
self.swap_in_threshold = 0
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
def get_num_new_matched_tokens(
self, ori_request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
# Returns (tokens available beyond num_computed_tokens, load_async).
# The request is deep-copied and its get_hash_new_full_blocks attribute
# is cleared on the copy before it is sent over RPC (presumably so the
# object can be serialized - TODO confirm).
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
"get_matched_num_and_touch", request)
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
self.num_cpu_computed_tokens[
request.request_id] = num_cpu_computed_tokens
# Report a match only when the gain reaches swap_in_threshold.
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
return num_cpu_computed_tokens - num_computed_tokens, load_async
else:
return 0, load_async
def update_state_after_alloc(self, request: "Request"):
# Remember that blocks were allocated for this request in the current step.
self.allocated_req_ids.add(request.request_id)
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
# Builds the CPUOffloadingConnectorMetadata for this step and resets the
# per-step bookkeeping.
# num_tokens: request id -> computed tokens + tokens scheduled this step.
num_tokens = {}
# process scheduled_new_reqs
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
num_tokens[req_id] = (
req.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
# process scheduled_cached_reqs
cached_reqs = scheduler_output.scheduled_cached_reqs
for idx, req_id in enumerate(cached_reqs.req_ids):
num_tokens[req_id] = (
cached_reqs.num_computed_tokens[idx] +
scheduler_output.num_scheduled_tokens[req_id])
# Requests that were looked up this step but were neither allocated nor
# scheduled.
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
self.allocated_req_ids -
scheduler_output.num_scheduled_tokens.keys())
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
num_tokens,
unallocated_req_ids)
metadata = CPUOffloadingConnectorMetadata(
requests={},
finished_req_ids=set(self.finished_req_ids),
)
# New requests: counters come from the values recorded in
# get_num_new_matched_tokens; only the first entry of block_ids is used.
for req in scheduler_output.scheduled_new_reqs:
req_id = req.req_id
gpu_block_ids = req.block_ids[0]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_computed_tokens=req.num_computed_tokens,
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
# Cached requests: all three counters are set to the scheduler's
# num_computed_tokens for that request.
for idx, req_id in enumerate(cached_reqs.req_ids):
gpu_block_ids = cached_reqs.new_block_ids[idx]
metadata.requests[req_id] = ReqMeta(
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
num_scheduled_tokens=scheduler_output.
num_scheduled_tokens[req_id],
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
self.num_gpu_computed_tokens.clear()
self.num_cpu_computed_tokens.clear()
self.allocated_req_ids.clear()
self.finished_req_ids.clear()
return metadata
def request_finished(self, ori_request: "Request"):
# Queue the request id for the next metadata build and notify the
# metadata server; the RPC receives a deep copy with
# get_hash_new_full_blocks cleared.
request = copy.deepcopy(ori_request)
request.get_hash_new_full_blocks = None
self.finished_req_ids.append(request.request_id)
# inform metadata server to record request, and free it after finish sending
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
request)
class CPUOffloadingConnectorWorker:
    """Worker-side half of the CPU offloading KV connector.

    Moves KV-cache blocks between the device caches handed to
    ``register_kv_caches`` and the CPU caches obtained from the metadata
    server. Loads are issued one layer at a time on ``load_stream``; saves
    for finished requests run on a background thread using ``save_stream``.
    """

    def __init__(self, vllm_config: VllmConfig):
        """Create streams, queues and the background save thread.

        The metadata server thread is started only on the worker whose
        data-parallel, tensor-parallel and pipeline-parallel ranks are all 0.

        Args:
            vllm_config: The vLLM configuration of this worker.
        """
        logger.info("init CPUOffloadingConnectorWorker")
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.pp_rank = get_pp_group().rank_in_group
        self.tp_group = get_tp_group()
        self.tp_rank = self.tp_group.rank_in_group
        self.tp_world_size = self.tp_group.world_size
        self.use_mla = vllm_config.model_config.use_mla
        self.requests: dict[str, ReqMeta] = {}
        self.load_stream = torch.npu.Stream()
        self.save_stream = torch.npu.Stream()
        self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
        # The RPC client owns a single socket; the short-lived threads
        # started by get_finished() must not use it concurrently.
        self._rpc_lock = threading.Lock()
        self.load_block_mapping: list[tuple[int, int]] = []
        self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
        self.save_output_queue: queue.Queue[str] = queue.Queue()
        # _save_listener never returns, so the thread has to be a daemon or
        # it would keep the interpreter alive at shutdown.
        self.save_thread = threading.Thread(target=self._save_listener,
                                            daemon=True)
        self.save_thread.start()
        self.done_sending_count: defaultdict[str, int] = defaultdict(int)
        # start metadata server to init cpu_kv_cache_manager and handle rpc requests
        # all dp shared the same metadata server, only start the process on data_rank 0
        if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
            config = VllmConfig()
            config.cache_config = vllm_config.cache_config
            config.parallel_config = vllm_config.parallel_config
            config.kv_transfer_config = vllm_config.kv_transfer_config
            self.init_metadata_server(config)
        self._wait_for_metadata_process_start()

    def init_metadata_server(self, vllm_config: VllmConfig):
        """Run the metadata server in a daemon thread of this process."""
        self.metadata_thread = threading.Thread(
            target=MetadataServerProc.run_metadata_server,
            args=(vllm_config, ),
        )
        self.metadata_thread.daemon = True
        self.metadata_thread.start()

    def _wait_for_metadata_process_start(self):
        """Block until the metadata server answers the ``ready`` RPC."""
        while True:
            try:
                if self.zmq_rpc_client.call("ready"):
                    break
            except Exception as e:
                logger.info(f"wait for metadata server to start, error: {e}")
                time.sleep(1)

    def bind_connector_metadata(
            self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
        """Merge scheduler metadata and derive this step's work.

        For every request, the blocks between ``num_gpu_computed_tokens`` and
        ``num_computed_tokens`` are queued as (cpu_block, gpu_block) pairs to
        load. Requests reported as finished are handed to the save thread.
        """
        for req_id, req in connector_metadata.requests.items():
            if req_id in self.requests:
                self.requests[req_id].update(req)
                req = self.requests[req_id]
            else:
                self.requests[req_id] = req
            for i in range(req.num_gpu_computed_tokens // self.block_size,
                           req.num_computed_tokens // self.block_size):
                self.load_block_mapping.append(
                    (req.cpu_block_ids[i], req.gpu_block_ids[i]))
        for req_id in connector_metadata.finished_req_ids:
            if req_id in self.requests:
                self.save_input_queue.put((req_id, self.requests[req_id]))

    def clear_connector_metadata(self) -> None:
        """Drop the load mapping built by ``bind_connector_metadata``."""
        self.load_block_mapping.clear()

    def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
        """Remember the device caches and fetch the matching CPU caches.

        The CPU caches are requested from the metadata server through the
        ``init_cpu_kv_caches`` RPC and stored as a list, in the order of the
        returned mapping's values.
        """
        self.gpu_kv_caches = kv_caches
        model_config = self.vllm_config.model_config
        mla_config: Optional[MLAConfig] = None
        if model_config.use_mla:
            mla_config = MLAConfig(
                model_config.hf_text_config.kv_lora_rank,
                model_config.hf_text_config.qk_rope_head_dim)
        self.cpu_kv_caches = list(
            self.zmq_rpc_client.call(
                "init_cpu_kv_caches",
                self.pp_rank,
                self.tp_rank,
                get_kv_cache_spec(self.vllm_config),
                mla_config,
            ).values())

    def start_load_kv(self) -> None:
        """Reset the layer cursor and issue the load of layer 0."""
        self.current_layer = 0
        self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
        self.load_kv_layer(0)

    def wait_for_layer_load(self) -> None:
        """Wait for the pending layer load, then issue the next layer."""
        # TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
        self.load_stream.synchronize()
        self.current_layer += 1
        self.load_kv_layer(self.current_layer)

    def load_kv_layer(self, layer: int):
        """Copy the mapped blocks of one layer from CPU to device.

        Does nothing when ``layer`` equals the number of registered layers.
        """
        if layer == len(self.gpu_kv_caches):
            return
        gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
        cpu_kv_caches = self.cpu_kv_caches[layer]
        with torch.npu.stream(self.load_stream):
            for cpu_block_id, gpu_block_id in self.load_block_mapping:
                for gpu_layer_part, cpu_layer_part in zip(
                        gpu_kv_caches, cpu_kv_caches):
                    gpu_layer_part[gpu_block_id].copy_(
                        cpu_layer_part[cpu_block_id], non_blocking=True)

    def get_finished(self) -> set[str]:
        """Collect the requests whose save to CPU has completed.

        Returns:
            With a TP world size of 1, the ids saved by this worker. On TP
            rank 0, the ids that every TP rank has reported as saved. On any
            other rank, the ids saved locally (after sending them to rank 0).
        """
        done_sending: set[str] = set()
        while True:
            try:
                done_sending.add(self.save_output_queue.get_nowait())
            except queue.Empty:
                break
        for req_id in done_sending:
            # pop() instead of del: a missing id must not abort the step.
            self.requests.pop(req_id, None)
        if self.tp_world_size == 1:
            return done_sending
        if self.tp_rank != 0:
            self.tp_group.send_object(done_sending, dst=0)
            return done_sending

        for req_id in done_sending:
            self.done_sending_count[req_id] += 1
        other_ranks_finished_ids: list[str] = []
        for i in range(1, self.tp_world_size):
            other_ranks_finished_ids.extend(
                self.tp_group.recv_object(src=i))
        for req_id in other_ranks_finished_ids:
            self.done_sending_count[req_id] += 1
        all_done_sending: set[str] = set()
        for req_id in list(self.done_sending_count.keys()):
            if self.done_sending_count[req_id] == self.tp_world_size:
                del self.done_sending_count[req_id]
                all_done_sending.add(req_id)
        # release cpu_kv_cache after request sending finished
        # to avoid rpc blocking, use thread to call rpc asynchronously
        if all_done_sending:
            sending_finished_thread = threading.Thread(
                target=self._sending_finished, args=(all_done_sending, ))
            sending_finished_thread.daemon = True
            sending_finished_thread.start()
        return all_done_sending

    def _sending_finished(self, all_done_sending):
        """Tell the metadata server that the given requests are fully saved."""
        for req_id in all_done_sending:
            logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
            # Serialise the threads started by get_finished() on the socket.
            with self._rpc_lock:
                self.zmq_rpc_client.call("cache_and_free_slots", req_id)

    def _save_listener(self):
        """Background loop copying finished requests' blocks to the CPU cache.

        Takes ``(req_id, ReqMeta)`` items from ``save_input_queue``, copies
        the blocks not yet present on the CPU side, waits for the copies and
        then reports ``req_id`` on ``save_output_queue``. Never returns.
        """
        save_block_mapping = []
        while True:
            req_id, req = self.save_input_queue.get()
            for i in range(
                    req.num_cpu_computed_tokens // self.block_size,
                    min((req.num_computed_tokens + req.num_scheduled_tokens) //
                        self.block_size, len(req.cpu_block_ids))):
                save_block_mapping.append(
                    (req.gpu_block_ids[i], req.cpu_block_ids[i]))
            with torch.npu.stream(self.save_stream):
                # MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
                # non-MLA: kv_layer is list[tensor], typically means [k, v].
                if self.use_mla:
                    # Each TP rank copies every tp_world_size-th block.
                    start, step = self.tp_rank, self.tp_world_size
                else:
                    start, step = 0, 1
                for i in range(start, len(save_block_mapping), step):
                    gpu_block_id, cpu_block_id = save_block_mapping[i]
                    for cpu_kv_caches, gpu_kv_caches in zip(
                            self.cpu_kv_caches, self.gpu_kv_caches.values()):
                        for cpu_layer_part, gpu_layer_part in zip(
                                cpu_kv_caches, gpu_kv_caches):
                            cpu_layer_part[cpu_block_id].copy_(
                                gpu_layer_part[gpu_block_id],
                                non_blocking=True)
            self.save_stream.synchronize()
            self.save_output_queue.put(req_id)
            save_block_mapping.clear()
# Copied from vllm_ascend/worker/model_runner_v1.py.
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
forward_ctx = vllm_config.compilation_config.static_forward_context
block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
return kv_cache_spec

View File

@@ -0,0 +1,202 @@
import time
from collections import defaultdict
from typing import Optional
from vllm.utils import logger, sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
PrefixCachingMetrics)
from vllm.v1.core.single_type_kv_cache_manager import \
get_manager_for_kv_cache_spec
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
class CPUCacheStats:
    """Hit-rate bookkeeping for the CPU prefix cache."""

    def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
        """Initialise the counters.

        Args:
            enable_prefix_caching: Whether ``update`` should count anything.
            log_stats: When False, no ``PrefixCacheStats`` object is kept.
        """
        self.enable_prefix_caching = enable_prefix_caching
        self.log_stats = log_stats
        self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
        self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
        self.time_sec = int(time.time())

    def log(self):
        """Log the CPU prefix-cache hit rate, at most once every 10 seconds."""
        now_sec = int(time.time())
        if now_sec - self.time_sec < 10:
            return
        self.time_sec = now_sec
        logger.info("CPU Prefix cache hit rate: %.1f%%",
                    self.cpu_prefix_cache_metrics.hit_rate * 100)

    def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
        """Get (and reset) the prefix cache stats.

        Returns:
            The current prefix caching stats, or None if logging is disabled.
        """
        if self.log_stats:
            current = self.prefix_cache_stats
            self.prefix_cache_stats = PrefixCacheStats()
            return current
        return None

    def update(self, num_tokens, num_computed_tokens):
        """Accumulate one request into the current stats.

        Does nothing unless both stats logging and prefix caching are on.
        """
        # Note the function is called by scheduler
        if not (self.log_stats and self.enable_prefix_caching):
            return
        assert self.prefix_cache_stats is not None
        stats = self.prefix_cache_stats
        stats.requests += 1
        stats.queries += num_tokens
        stats.hits += num_computed_tokens

    def set_cache_stats(self, num_tokens, num_computed_tokens):
        """Overwrite the current stats with the values of a single request."""
        assert self.prefix_cache_stats is not None
        stats = self.prefix_cache_stats
        stats.hits = num_computed_tokens
        stats.queries = num_tokens
        stats.requests = 1
class CPUKVCacheManager:
    """Block manager for the CPU-side KV cache.

    Wraps a ``BlockPool`` and a single-type KV cache manager and keeps
    per-request bookkeeping (block hashes, blocks touched ahead of
    allocation, allocation failures, token counts and requests waiting to be
    freed once sending has finished).
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        num_cpu_blocks: int,
        caching_hash_algo: str = "builtin",
        use_eagle: bool = False,
        enable_kv_cache_events: bool = False,
    ) -> None:
        """Create the block pool and the per-request bookkeeping tables.

        Args:
            kv_cache_spec: Spec of a KV cache layer; provides the block size.
            num_cpu_blocks: Number of blocks in the CPU block pool.
            caching_hash_algo: ``"sha256"`` selects ``sha256``; any other
                value selects the builtin ``hash``.
            use_eagle: Forwarded to ``find_longest_cache_hit``.
            enable_kv_cache_events: Forwarded to ``BlockPool``.
        """
        self.block_size = kv_cache_spec.block_size
        self.num_cpu_blocks = num_cpu_blocks
        self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
        self.use_eagle = use_eagle
        # NOTE(review): the second positional argument presumably enables
        # caching in BlockPool - confirm against its signature.
        self.block_pool = BlockPool(self.num_cpu_blocks, True,
                                    enable_kv_cache_events)
        self.single_type_manager = get_manager_for_kv_cache_spec(
            kv_cache_spec=kv_cache_spec,
            block_pool=self.block_pool,
            kv_cache_group_id=0,
        )
        # Record kv block hashes, avoid redundant computation.
        self.req_to_block_hashes: defaultdict[
            str, list[BlockHash]] = defaultdict(list)
        # Record blocks touched in get_matched_num_and_touch().
        self.req_to_computed_blocks: defaultdict[
            str, list[KVCacheBlock]] = defaultdict(list)
        # Record the request that failed to allocate.
        self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
        self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
        self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
                                             log_stats=True)
        # Record request that will be free after finish sending
        self.req_to_free: defaultdict[str, Request] = defaultdict(Request)

    def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
        """Find the longest cached prefix of ``request`` and touch its blocks.

        The matched blocks are remembered in ``req_to_computed_blocks`` and
        the hit statistics are updated and possibly logged.

        Returns:
            ``(num_computed_tokens, False)``; the second element is always
            False here. ``(0, False)`` when the request asks for prompt
            logprobs.
        """
        # When the request requires prompt logprobs, we skip prefix caching.
        if (request.sampling_params.prompt_logprobs is not None):
            return 0, False
        request_id = request.request_id
        # The block hashes for the request may already be computed
        # if the scheduler has tried to schedule the request before.
        block_hashes = self.req_to_block_hashes[request_id]
        if not block_hashes:
            block_hashes = request.block_hashes
            self.req_to_block_hashes[request_id] = block_hashes
        max_cache_hit_length = request.num_tokens - 1
        computed_blocks = self.single_type_manager.find_longest_cache_hit(
            block_hashes=block_hashes,
            max_length=max_cache_hit_length,
            kv_cache_group_ids=[0],
            block_pool=self.block_pool,
            kv_cache_spec=self.single_type_manager.kv_cache_spec,
            use_eagle=self.use_eagle,
        )
        num_computed_tokens = len(computed_blocks[0]) * self.block_size
        self.req_to_computed_blocks[request_id] = computed_blocks[0]
        # We should touch these blocks in the concurrent scenarios.
        self.block_pool.touch(computed_blocks)
        # cpu prefix cache status set and log
        assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
        self.cpu_cache_stats.set_cache_stats(request.num_tokens,
                                             num_computed_tokens)
        self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
            self.cpu_cache_stats.prefix_cache_stats)
        self.cpu_cache_stats.log()
        return num_computed_tokens, False

    def _release_ahead_touch(self, request_id: str):
        """Free the blocks touched for ``request_id`` before allocation."""
        computed_blocks = self.req_to_computed_blocks[request_id]
        if computed_blocks:
            self.single_type_manager.block_pool.free_blocks(
                reversed(computed_blocks))
        self.req_to_computed_blocks.pop(request_id, None)

    def allocate_slots(self, req_to_num_tokens: dict[str, int],
                       unallocated_req_ids: set[str]) -> dict[str, list[int]]:
        """Allocate CPU blocks for the given requests.

        Args:
            req_to_num_tokens: Total number of tokens needed per request id.
            unallocated_req_ids: Requests whose slots are freed first.

        Returns:
            Mapping from request id to block ids (previously matched blocks
            followed by newly allocated ones). Requests that could not be
            allocated, now or earlier, are omitted.
        """
        for request_id in unallocated_req_ids:
            self._free_slots(request_id)
        req_to_new_blocks = {}
        for request_id, num_tokens in req_to_num_tokens.items():
            if self.req_failed_to_allocate[request_id]:
                continue
            new_computed_blocks = self.req_to_computed_blocks[request_id]
            num_blocks_to_allocate = (
                self.single_type_manager.get_num_blocks_to_allocate(
                    request_id=request_id,
                    num_tokens=num_tokens,
                    new_computed_blocks=new_computed_blocks,
                ))
            if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
                self._release_ahead_touch(request_id)
                self.req_failed_to_allocate[request_id] = True
                continue
            # Append the new computed blocks to the request blocks until now to
            # avoid the case where the new blocks cannot be allocated.
            self.single_type_manager.save_new_computed_blocks(
                request_id, new_computed_blocks)
            # Allocate new blocks but do not cache now.
            new_blocks = self.single_type_manager.allocate_new_blocks(
                request_id, num_tokens)
            self.req_to_num_tokens[request_id] = num_tokens
            # No need to release ref_cnt because we use officially.
            self.req_to_computed_blocks.pop(request_id, None)
            req_to_new_blocks[request_id] = [
                block.block_id for block in new_computed_blocks + new_blocks
            ]
        return req_to_new_blocks

    def record_request_cache_and_free_slots(self, request: Request):
        """Remember ``request`` so it can be cached and freed later."""
        logger.debug(
            f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
        )
        self.req_to_free[request.request_id] = request

    def cache_and_free_slots(self, request_id: str):
        """Cache the blocks of a recorded request, then free its slots.

        Logs an error and returns when ``request_id`` was never recorded
        through ``record_request_cache_and_free_slots``.
        """
        logger.debug(
            f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
        )
        if request_id not in self.req_to_free:
            # Logger has no ``Error`` method; the previous spelling raised
            # AttributeError instead of logging and returning.
            logger.error(
                f"request {request_id} not in req_to_free, maybe bug!")
            return
        request = self.req_to_free[request_id]
        if not self.req_failed_to_allocate[request_id]:
            self.single_type_manager.cache_blocks(
                request,
                self.req_to_num_tokens[request_id],
            )
        self._free_slots(request_id)
        logger.debug(
            f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
        del self.req_to_free[request_id]

    def _free_slots(self, request_id: str):
        """Release every block and bookkeeping entry of ``request_id``."""
        # This function is designed to be reentrant.
        self._release_ahead_touch(request_id)
        self.single_type_manager.free(request_id)
        self.req_to_block_hashes.pop(request_id, None)
        self.req_to_computed_blocks.pop(request_id, None)
        self.req_failed_to_allocate.pop(request_id, None)
        self.req_to_num_tokens.pop(request_id, None)

View File

@@ -0,0 +1,269 @@
import math
import os
import pickle
from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.config import KVTransferConfig, VllmConfig
from vllm.utils import get_dtype_size, logger, make_zmq_socket
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
CPUKVCacheManager
@dataclass
class MLAConfig:
    """Sizes used to split an MLA cache tensor along its last dimension."""

    # Width of the first split chunk.
    nope_dim: int
    # Width of the second split chunk.
    rope_dim: int
def get_cpu_offload_connector(
        vllm_config: VllmConfig) -> Optional[KVTransferConfig]:
    """Return the KV transfer config of the CPU offloading connector.

    The top-level ``kv_transfer_config`` is returned when it is a
    ``CPUOffloadingConnector``. For a ``MultiConnector`` the entries under
    the ``connectors`` key of its extra config are searched and the first
    ``CPUOffloadingConnector`` is returned as a new ``KVTransferConfig``.

    Args:
        vllm_config: Configuration whose ``kv_transfer_config`` is inspected.

    Returns:
        The matching config, or None when no CPU offloading connector is
        configured.
    """
    kv_transfer_config = vllm_config.kv_transfer_config
    if kv_transfer_config is None:
        return None
    if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
        return kv_transfer_config
    if kv_transfer_config.kv_connector == "MultiConnector":
        # A MultiConnector without a "connectors" entry has nothing to
        # search; treat it as empty instead of iterating over None.
        connectors = kv_transfer_config.kv_connector_extra_config.get(
            "connectors") or []
        for ktc in connectors:
            candidate = KVTransferConfig(**ktc)
            if candidate.kv_connector == "CPUOffloadingConnector":
                return candidate
    return None
class MetadataServer:
    """RPC server owning the shared-memory CPU KV caches.

    Requests arrive on a ZeroMQ ROUTER socket as pickled
    ``(func_name, args, kwargs)`` tuples and are dispatched through
    ``self.functions``. ``ZMQRPCClient`` is the matching client.
    """

    METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
    DEFAULT_CPU_SWAP_SPACE_GB = 800

    class ZMQRPCClient:
        """Blocking RPC client talking to the metadata server."""

        def __init__(self, identity: Optional[str] = None):
            """Connect a DEALER socket to the metadata server.

            Args:
                identity: Socket identity. Defaults to ``worker-<pid>`` of
                    the process creating the client.
            """
            # Resolve the default here, not in the signature: a default
            # argument is evaluated once at class-definition time, so forked
            # workers would all inherit the parent's pid.
            if identity is None:
                identity = f"worker-{os.getpid()}"
            logger.info(f"metadata client for worker {identity} started")
            self.ctx = zmq.Context()  # type: ignore
            self.socket = make_zmq_socket(
                self.ctx,
                MetadataServer.METADATA_SERVER_ADDRESS,
                zmq.DEALER,  # type: ignore
                bind=False,
                identity=identity.encode(),
                linger=0)

        def call(self, func_name: str, *args, **kwargs) -> Any:
            """Invoke ``func_name`` on the server and return its result.

            For ``init_cpu_kv_caches`` the returned shared-memory segments
            are wrapped into tensors (split into two parts per layer when an
            MLA config is returned) and a ``{layer_name: tensor(s)}`` mapping
            is returned instead of the raw reply.

            Raises:
                Exception: Whatever error object the server sent back.
            """
            request = (func_name, args, kwargs)
            self.socket.send(b"", zmq.SNDMORE)  # type: ignore
            self.socket.send(pickle.dumps(request))
            _ = self.socket.recv()
            # NOTE: pickle is only acceptable because the peer is the local
            # metadata server; never point this at an untrusted endpoint.
            response = pickle.loads(self.socket.recv())
            result, error = response
            if error is not None:
                # No exception is being handled here, so logger.exception
                # would only append an empty traceback.
                logger.error(f"call metadata sever error: {error}")
                raise error
            if func_name == "init_cpu_kv_caches":
                (memory_dict, layer_size, layer_dtype, mla_config) = result
                # shared_memory_dict is recorded in self to close
                self.shared_memory_dict = memory_dict
                result = {}
                for key, shm in memory_dict.items():
                    tensor = torch.frombuffer(
                        shm.buf, dtype=layer_dtype).reshape(layer_size)
                    if mla_config is not None:
                        tensor = tensor.split(
                            [mla_config.nope_dim, mla_config.rope_dim], dim=-1)
                    result[key] = tensor
            return result

        def __del__(self):
            # will be finalized by outer process
            self.socket.close()
            self.ctx.term()
            if hasattr(self, 'shared_memory_dict'):
                for shm in self.shared_memory_dict.values():
                    shm.close()

    def __init__(self, vllm_config: VllmConfig):
        """Bind the ROUTER socket and read the CPU swap-space budget.

        The budget comes from ``cpu_swap_space_gb`` in the connector's extra
        config, defaulting to ``DEFAULT_CPU_SWAP_SPACE_GB``.
        """
        self.world_size = vllm_config.parallel_config.world_size
        self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
        kv_transfer_config = get_cpu_offload_connector(vllm_config)
        assert kv_transfer_config is not None
        available_memory_gb = kv_transfer_config.get_from_extra_config(
            "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
        self.available_memory = available_memory_gb * 1024 * 1024 * 1024
        logger.info(f"cpu swap space: {self.available_memory} bytes")
        self.ctx = zmq.Context()  # type: ignore
        self.socket = make_zmq_socket(
            self.ctx,
            MetadataServer.METADATA_SERVER_ADDRESS,
            zmq.ROUTER,  # type: ignore
            bind=True,
            linger=0)
        self.functions: dict[str, Callable] = {
            "init_cpu_kv_caches": self.init_cpu_kv_caches,
            "post_init": self.post_init,
            "ready": self.ready,
        }
        self.shared_memory = {}  # type: ignore
        self.num_cpu_blocks = -1

    @staticmethod
    def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
        """Create a shared-memory segment, replacing a stale one if present."""
        try:
            existing_shm = SharedMemory(name=name, create=False)
            existing_shm.close()
            existing_shm.unlink()
        except FileNotFoundError:
            pass
        return SharedMemory(name=name, create=True, size=size)

    def ready(self):
        """RPC used by clients to check that the server is serving."""
        return True

    def init_cpu_kv_caches(
        self,
        pp_rank: int,
        tp_rank: int,
        kv_cache_specs: dict[str, AttentionSpec],
        mla_config: Optional[MLAConfig],
    ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
               Optional[MLAConfig]]:
        """Create (or return the cached) CPU KV caches for one rank.

        Args:
            pp_rank: Pipeline-parallel rank of the caller.
            tp_rank: Tensor-parallel rank of the caller; replaced by 0 for
                MLA layers, which share one cache across TP ranks.
            kv_cache_specs: Spec per layer name; all layers must have the
                same ``page_size_bytes``.
            mla_config: Required when the layers use MLA.

        Returns:
            ``(shared_memory_by_layer, layer_shape, dtype, mla_config)``,
            where ``mla_config`` is None for non-MLA layers.
        """
        logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
        # follow the assumption that each layer has the same spec
        layer = next(iter(kv_cache_specs.values()))
        assert all(layer.page_size_bytes == spec.page_size_bytes
                   for spec in kv_cache_specs.values())
        # mla shares the same kv cache among different tp
        if layer.use_mla:
            tp_rank = 0
        if (pp_rank, tp_rank) in self.shared_memory:
            return self.shared_memory[(pp_rank, tp_rank)]
        available_memory = self.available_memory
        shared_memory_dict = {}
        if layer.use_mla:
            available_memory //= self.pipeline_parallel_size
            available_memory //= len(kv_cache_specs)
            num_blocks = available_memory // layer.page_size_bytes
            layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
                          layer.head_size)  # type: ignore
        else:
            available_memory //= self.world_size
            available_memory //= len(kv_cache_specs)
            num_blocks = available_memory // layer.page_size_bytes
            layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
                          layer.head_size)  # type: ignore
        nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
        for layer_name in kv_cache_specs.keys():
            # only this format can share during ZeroMQ+pickle
            shared_memory_dict[
                layer_name] = MetadataServer._safe_create_shared_memory(
                    f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
        if layer.use_mla:
            assert mla_config is not None
            assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
            self.shared_memory[(pp_rank,
                                tp_rank)] = (shared_memory_dict, layer_size,
                                             layer.dtype, mla_config)
        else:
            self.shared_memory[(pp_rank,
                                tp_rank)] = (shared_memory_dict, layer_size,
                                             layer.dtype, None)
        # Track the smallest block count seen across ranks.
        if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
            self.num_cpu_blocks = num_blocks
            self.layer = layer
        return self.shared_memory[(pp_rank, tp_rank)]

    def post_init(self):
        """Create the CPU block manager and expose its RPC functions."""
        # different processors in data parallel may call multiple times
        if hasattr(self, 'cpu_block_manager'):
            return
        # do shared_memory() at least once
        logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
        assert self.num_cpu_blocks >= 0
        self.cpu_block_manager = CPUKVCacheManager(self.layer,
                                                   self.num_cpu_blocks)
        self.functions.update({
            "get_matched_num_and_touch":
            self.cpu_block_manager.get_matched_num_and_touch,
            "allocate_slots":
            self.cpu_block_manager.allocate_slots,
            "record_request_cache_and_free_slots":
            self.cpu_block_manager.record_request_cache_and_free_slots,
            "cache_and_free_slots":
            self.cpu_block_manager.cache_and_free_slots,
        })

    def serve_step(self):
        """Receive one request, execute it and send back the reply.

        The reply is a pickled ``(result, error)`` tuple; exactly one of the
        two is None unless the function itself returned None.
        """
        client_id = self.socket.recv()
        _ = self.socket.recv()
        raw_msg = self.socket.recv()
        try:
            # NOTE: pickle is only acceptable because clients are local,
            # trusted worker processes on the IPC socket.
            func_name, args, kwargs = pickle.loads(raw_msg)
        except Exception as e:
            response = (None, Exception(f"Invalid request: {str(e)}"))
        else:
            if func_name in self.functions:
                try:
                    result = self.functions[func_name](*args, **kwargs)
                    response = (result, None)  # type: ignore
                except Exception as e:
                    logger.exception(f"metadata execute error: {e}")
                    response = (None, e)  # type: ignore
            else:
                response = (None, NameError(f"Function {func_name} not found"))
        self.socket.send(client_id, zmq.SNDMORE)  # type: ignore
        self.socket.send(b"", zmq.SNDMORE)  # type: ignore
        self.socket.send(pickle.dumps(response))

    def shutdown(self):
        """Close the socket, remove the IPC file and release shared memory."""
        self.socket.close()
        self.ctx.term()
        socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
            "ipc://", "")
        if os.path.exists(socket_path):
            os.remove(socket_path)
        for cached in self.shared_memory.values():
            for shm in cached[0].values():
                shm.close()
                shm.unlink()
class MetadataServerProc:
    """Entry point that runs a ``MetadataServer`` until it fails or exits."""

    @staticmethod
    def run_metadata_server(vllm_config: VllmConfig):
        """Serve metadata RPCs forever.

        Returns immediately when prefix caching is disabled or no CPU
        offloading connector is configured. The server is always shut down
        on the way out; exceptions are logged and re-raised.
        """
        if not vllm_config.cache_config.enable_prefix_caching:
            return
        if get_cpu_offload_connector(vllm_config) is None:
            return

        shutdown_requested = False

        def _signal_handler(signum, frame):
            nonlocal shutdown_requested
            if shutdown_requested:
                return
            shutdown_requested = True
            raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the worker
        # signal.signal(signal.SIGTERM, _signal_handler)
        # signal.signal(signal.SIGINT, _signal_handler)
        metadata_server: Optional[MetadataServer] = None
        try:
            metadata_server = MetadataServer(vllm_config)
            logger.info("Metadata server started.")
            while True:
                metadata_server.serve_step()
        except SystemExit:
            logger.info("Metadata server exiting.")
            raise
        except Exception as e:
            logger.exception(f"Metadata server error: {e}.")
            raise e
        finally:
            if metadata_server is not None:
                metadata_server.shutdown()

View File

@@ -1,4 +1,5 @@
import contextlib
import copy
import json
import math
import os
@@ -17,6 +18,7 @@ import torch
import zmq
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
LLMException, LLMRole)
from vllm import envs
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -184,6 +186,7 @@ class LLMDataDistCMgrConnectorScheduler():
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}
def get_num_new_matched_tokens(
self, request: "Request",
@@ -248,7 +251,12 @@ class LLMDataDistCMgrConnectorScheduler():
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params)
meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_send.clear()
return meta
@@ -275,6 +283,9 @@ class LLMDataDistCMgrConnectorScheduler():
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(computed_block_ids), request.request_id)
# Prefill request on remote. It will be read from D upon completion
self._reqs_need_send[request.request_id] = time.perf_counter(
) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
@@ -341,6 +352,7 @@ class LLMDataDistCMgrConnectorWorker():
os.environ["HCCL_DETERMINISTIC"] = "true"
self.done_receiving_counts: defaultdict[str,
set[int]] = defaultdict(set)
self.reqs_to_send: dict[str, float] = {}
def listen_for_agent_metadata_req(self, event: threading.Event):
assert self.local_agent_metadata is not None
@@ -375,16 +387,13 @@ class LLMDataDistCMgrConnectorWorker():
)
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
finished_req_id = decode_msg[0]
decode_tp_rank = decode_msg[1]
decode_tp_size = decode_msg[2]
with self.thread_lock:
if self._increment_task_count(finished_req_id,
decode_tp_rank,
decode_tp_size):
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
if finished_req_id in self.reqs_to_send:
self.finished_reqs.add(finished_req_id)
del self.reqs_to_send[finished_req_id]
sock.send_multipart(
(identity, b"", b"receiving decode finished"))
else:
@@ -392,24 +401,6 @@ class LLMDataDistCMgrConnectorWorker():
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
)
def _increment_task_count(self, request_id: str, tp_rank: int,
decode_tp_size: int):
if request_id not in self.done_receiving_counts:
self.done_receiving_counts[request_id] = set()
if tp_rank in self.done_receiving_counts[request_id]:
logger.warning(
f"Received duplicate done signal for request {request_id} "
f"from tp rank {tp_rank}. Ignoring.")
return False
self.done_receiving_counts[request_id].add(tp_rank)
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
self.done_receiving_counts.pop(request_id)
logger.info("All transfers completed for request: "
f"{request_id}. Total ranks: "
f"{decode_tp_size}.")
return True
return False
def init_llm_datadist(self):
assert self.local_agent_metadata is not None
llm_config = LLMConfig()
@@ -502,8 +493,11 @@ class LLMDataDistCMgrConnectorWorker():
assert self.local_agent_metadata is not None
kv_cache_dtype = first_kv_cache.dtype
self.use_mla: bool = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sfa: bool = len(first_kv_cache_tuple) == 3
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
# MHA case. [2 (k and v), num_blocks, ...]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
@@ -549,6 +543,58 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
elif self.use_sfa:
cache_k_normed_addr_list = []
cache_k_pe_addr_list = []
cache_k_idx_addr_list = []
k_normed = None
k_pe = None
k_idx = None
for cache_or_caches in kv_caches.values():
assert len(cache_or_caches) > 1
k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[
1], cache_or_caches[2]
cache_k_normed_addr_list.append(k_normed.data_ptr())
cache_k_pe_addr_list.append(k_pe.data_ptr())
cache_k_idx_addr_list.append(k_idx.data_ptr())
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
cache_k_idx_addr_list)
cache_desc_k_normed = CacheDesc(
len(self.cache_addr[0]), [*k_normed.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_pe = CacheDesc(
len(self.cache_addr[1]), [*k_pe.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_idx = CacheDesc(
len(self.cache_addr[2]), [*k_idx.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=0)
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=1)
cache_key_k_idx = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=2)
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
cache_desc_k_idx)
self.cache_key = (cache_key_k_normed, cache_key_k_pe,
cache_key_k_idx)
try:
cache_k_normed = self.cache_manager.register_blocks_cache(
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
cache_k_pe = self.cache_manager.register_blocks_cache(
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
cache_k_idx = self.cache_manager.register_blocks_cache(
self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
logger.info("LLMDataDistWorker: End of register Paged Cache.")
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
else:
for cache_or_caches in kv_caches.values():
for cache in cache_or_caches:
@@ -605,6 +651,7 @@ class LLMDataDistCMgrConnectorWorker():
for future in futures:
future.add_done_callback(handle_exception)
self.reqs_to_send.update(metadata.reqs_to_send)
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
assert self.local_agent_metadata is not None
@@ -767,24 +814,24 @@ class LLMDataDistCMgrConnectorWorker():
cluster_id = self.add_remote_agent(metadata)
return cluster_id
def send_finish_to_remote(self, host: str, port: int, request_id):
    """Notify one remote endpoint that ``request_id`` has finished.

    Sends a ``ReqForFinished`` event carrying the request id, this worker's
    TP rank and TP size, then waits for the reply. Failures are logged, not
    raised.
    """
    remote_url = f"tcp://{host}:{port}"
    logger.debug(f"Sending finished to remote: {remote_url}")
    payload = msgspec.msgpack.Encoder().encode([
        LLMDataDistCMgrEvent.ReqForFinished,
        [request_id, self.tp_rank, self.tp_size],
    ])
    with zmq_ctx(zmq.REQ, remote_url) as sock:  # type: ignore[attr-defined]
        try:
            sock.send(payload)
            logger.debug(
                f"Request id {request_id} finished message send to remote {remote_url}"
            )
            _ = sock.recv()
        except Exception as e:
            logger.error(
                f"Failed to send reqest_id {request_id} to prefill: {e}")
def send_finish_to_remote(self, host: str, ports: list[int], request_id):
    """Notify every given remote port that ``request_id`` has finished.

    A ``ReqForFinished`` event carrying the request id is sent to each
    ``tcp://host:port`` in turn, waiting for the reply each time. A failure
    on one port is logged and does not stop the remaining ports.

    Args:
        host: Remote host name or address.
        ports: Remote ports to notify.
        request_id: Id of the finished request.
    """
    # The message does not depend on the port: encode it once.
    msg_send = msgspec.msgpack.Encoder().encode(
        [LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
    for port in ports:
        url = f"tcp://{host}:{port}"
        logger.debug(f"Sending finished to remote: {url}")
        with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
            try:
                sock.send(msg_send)
                logger.debug(
                    f"Request id {request_id} finished message send to remote {url}"
                )
                _ = sock.recv()
            except Exception as e:
                logger.error(
                    f"Failed to send reqest_id {request_id} to prefill: {e}"
                )
def _read_blocks(
self,
@@ -834,6 +881,38 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
elif self.use_sfa:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
remote_cache_key_k_idx = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=2)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_idx,
self.cache[2], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")
@@ -851,7 +930,10 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
self.send_finish_to_remote(remote_ip, remote_port, request_id)
remote_ports = list(
range(remote_port + self.tp_rank,
remote_port + int(remote_tp_size), self.tp_size))
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
with self.thread_lock:
self.finished_reqs.add(request_id)
@@ -859,8 +941,19 @@ class LLMDataDistCMgrConnectorWorker():
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requuests."""
import copy
now = time.perf_counter()
with self.thread_lock:
while self.reqs_to_send:
req_id, expires = next(iter(self.reqs_to_send.items()))
if now < expires:
break
logger.warning(
"Some requests in prefill node fail to receive KV Cache transfer done signal. "
"If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
)
if req_id in self.reqs_to_send:
self.finished_reqs.add(req_id)
del self.reqs_to_send[req_id]
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
self.finished_reqs.clear()
if self.llm_datadist_role == LLMRole.PROMPT:
@@ -891,4 +984,4 @@ def zmq_ctx(socket_type: Any,
yield socket
finally:
if ctx is not None:
ctx.destroy(linger=0)
ctx.destroy(linger=0)

View File

@@ -1,556 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
class MoECommMethod(ABC):
    """Abstract interface implemented by every MoE communication strategy.

    A strategy provides four hooks: ``prepare`` / ``finalize`` wrap the call
    to ``quant_method.apply``, while ``permute`` / ``unpermute`` wrap the
    expert MLP itself.
    """

    def __init__(self, moe_config: "FusedMoEConfig"):
        # Subclasses read parallelism settings from this object.
        self.moe_config = moe_config

    @abstractmethod
    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Set up communication before ``quant_method.apply`` runs.

        Implementations may initialise whatever resources or configuration
        they need and return the (possibly transformed) hidden states and
        router logits.
        """

    @abstractmethod
    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """Tear down communication after ``quant_method.apply`` has run.

        Implementations may release resources or configuration here and
        return the final hidden states.
        """

    @abstractmethod
    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Pre-process tokens before the expert MLP.

        Args:
            hidden_states (torch.Tensor): Shape (num_tokens, hidden_size).
            topk_ids (torch.Tensor): Shape (num_tokens, top_k_num).
            topk_weights (torch.Tensor): Shape (num_tokens, top_k_num).
            expert_map (torch.Tensor): Shape (global_num_experts, ); maps
                global expert IDs to local expert IDs.
            num_experts (int): Number of local experts (experts on this
                device).
            apply_a8_quantization (bool): Whether A8 quantization
                (W4A8 and W8A8) is applied.

        Returns:
            A 4-tuple of:
                - permuted_hidden_states (torch.Tensor): Shape
                  (num_tokens * top_k_num, hidden_size), the hidden states
                  permuted according to ``topk_ids``.
                - expert_tokens (torch.Tensor): Shape (num_experts, ), the
                  number of tokens assigned to each expert.
                - dynamic_scale (torch.Tensor, optional): Shape
                  (num_experts, ), per-expert dynamic scale used for
                  quantization.
                - group_list_type (int): 0 for `cumsum`, 1 for `count`;
                  mainly consumed by `npu_grouped_matmul` to decide how to
                  handle the output.

        Raises:
            NotImplementedError: If the method is not implemented in the
                subclass.
        """

    @abstractmethod
    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Post-process the expert MLP output.

        Args:
            mlp_output (torch.Tensor): Shape
                (num_tokens * top_k_num, hidden_size), the MLP result.
            hidden_states (torch.Tensor): Shape (num_tokens, hidden_size);
                updated in place with the final output.
        """
class AllGatherCommImpl(MoECommMethod):
    """All-gather based MoE communication built on NPU-specific ops.

    Functionally the same as NativeAllGatherCommImpl, but uses NPU ops for
    better performance. It is meant to be compatible with all scenarios and
    is therefore the default MoE communication method.

    Pre-processing uses `torch_npu.npu_moe_init_routing_v2` and
    post-processing uses `torch_npu.npu_moe_token_unpermute`.

    NOTE(Yizhou): The natural pairings would be
    `torch_npu.npu_moe_init_routing_v2` with
    `torch_npu.npu_moe_finalize_routing`, or
    `torch_npu.npu_moe_token_permute` with
    `torch_npu.npu_moe_token_unpermute`. However
    `npu_moe_finalize_routing` leads to accuracy issues, so
    `torch_npu.npu_moe_token_unpermute` is used instead. This is a
    workaround and should be removed after the issue is fixed.
    """

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad and all-gather both inputs across DP ranks when DP size > 1.

        With DP size <= 1 the inputs are returned untouched.
        """
        if self.moe_config.dp_size > 1:
            max_tokens_across_dp = get_forward_context().max_tokens_across_dp
            # Remembered so that finalize() can strip the padding again.
            self.num_tokens = hidden_states.shape[0]

            pad_size = max_tokens_across_dp - self.num_tokens
            if pad_size > 0:
                row_padding = (0, 0, 0, pad_size)
                hidden_states = nn.functional.pad(hidden_states, row_padding)
                router_logits = nn.functional.pad(router_logits, row_padding)

            dp_group = self.moe_config.dp_group
            hidden_states = dp_group.all_gather(hidden_states, 0)
            router_logits = dp_group.all_gather(router_logits, 0)

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """Undo the gather from prepare() and optionally all-reduce.

        When DP size > 1 the hidden states are reduce-scattered and trimmed
        back to the token count recorded in prepare(). When
        ``reduce_results`` is set and TP size > 1 or EP size > 1 they are
        additionally all-reduced.
        """
        cfg = self.moe_config
        if cfg.dp_size > 1:
            hidden_states = data_parallel_reduce_scatter(hidden_states, dim=0)
            hidden_states = hidden_states[:self.num_tokens]

        needs_all_reduce = cfg.tp_size > 1 or cfg.ep_size > 1
        if reduce_results and needs_all_reduce:
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,  # noqa: F841
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Route tokens to experts via `npu_moe_init_routing_v2`.

        Stores ``topk_ids``, the (possibly masked) ``topk_weights`` and the
        expanded row index on ``self`` for use by unpermute(). Returns the
        permuted hidden states, the per-expert token tensor produced by the
        op, ``None`` for the dynamic scale, and group-list type 1.
        """
        token_count = hidden_states.shape[0]
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights

        range_start = 0
        if expert_map is not None:
            # FIXME: npu_grouped_matmul output random values at
            # [num_valid_tokens:, ...], so invalid tokens are filtered out by
            # zeroing their weights. Workaround; remove once fixed.
            keep = expert_map[topk_ids] != -1
            # NOTE: Equivalent to self.topk_weights[~keep] = 0.0, but ~keep
            # would dispatch to aclnnNonzeroV2, which ACL Graph does not
            # support.
            self.topk_weights = torch.where(keep, topk_weights, 0.0)
            range_start = self.moe_config.ep_rank * num_experts
        range_end = range_start + num_experts

        routing_result = torch_npu.npu_moe_init_routing_v2(
            hidden_states,
            topk_ids,
            active_num=token_count * self.moe_config.experts_per_token,
            expert_num=self.moe_config.num_experts,
            expert_tokens_num_type=1,  # Only support `count` mode now
            expert_tokens_num_flag=True,  # Output `expert_tokens`
            active_expert_range=[range_start, range_end],
            quant_mode=-1,
        )
        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
            routing_result)
        self.expanded_row_idx = expanded_row_idx

        count_mode = 1  # `count` mode
        return permuted_hidden_states, expert_tokens, None, count_mode

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Write the un-permuted, weighted MLP output into ``hidden_states``.

        Uses the row index and weights stored by permute().
        """
        restored = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=mlp_output,
            sorted_indices=self.expanded_row_idx,
            probs=self.topk_weights)
        hidden_states[:] = restored
class NativeAllGatherCommImpl(AllGatherCommImpl):
    """All-gather MoE communication using only native PyTorch ops.

    Intended to be compatible with all scenarios. Because no NPU-specific
    ops are used, performance may not be optimal, but it serves as a
    fallback where those ops are unavailable.
    """

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Sort token/expert pairs by local expert id and gather the rows.

        Stores ``sorted_token_indices`` and ``sorted_weights`` on ``self``
        for unpermute(). Returns the gathered hidden states, the per-expert
        token counts, ``None`` for the dynamic scale, and group-list type 1.
        """
        device = hidden_states.device
        num_tokens = hidden_states.shape[0]

        # One row id per (token, selected expert) pair, flattened.
        row_ids = torch.arange(num_tokens, device=device, dtype=torch.int64)
        row_ids = row_ids.unsqueeze(1).expand(
            -1, self.moe_config.experts_per_token).reshape(-1)

        # Flatten the routing tensors and translate to local expert ids.
        flat_weights = topk_weights.view(-1)
        flat_experts = topk_ids.view(-1)
        if expert_map is not None:
            local_experts = expert_map[flat_experts]
        else:
            local_experts = flat_experts

        # Pairs mapped to -1 do not belong to a local expert.
        valid = local_experts != -1
        # FIXME: npu_grouped_matmul output random values at
        # [num_valid_tokens:, ...], so invalid tokens are filtered out by
        # zeroing their weights. Workaround; remove once fixed.
        kept_weights = torch.where(
            valid, flat_weights,
            torch.zeros_like(flat_weights)).to(topk_weights.dtype)
        # Invalid pairs are parked in an extra bucket with id `num_experts`.
        overflow_bucket = torch.full_like(local_experts, num_experts)
        kept_experts = torch.where(valid, local_experts,
                                   overflow_bucket).to(topk_ids.dtype)

        # Sort by local expert id.
        # NOTE(review): the ids are reinterpreted as float32 before sorting;
        # this presumably relies on a 32-bit id dtype -- confirm.
        order = torch.argsort(kept_experts.view(torch.float32))
        self.sorted_token_indices = row_ids[order]
        self.sorted_weights = kept_weights[order]

        # Per-expert counts; equivalent to but faster than:
        # >>> torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        counts = torch.zeros(num_experts + 1,
                             device=device,
                             dtype=torch.int64)
        increments = torch.ones_like(kept_experts, dtype=torch.int64)
        counts.scatter_add_(0, kept_experts.to(torch.int64), increments)
        expert_tokens = counts[:num_experts]

        # Rearrange hidden_states into expert order.
        permuted_hidden_states = hidden_states[self.sorted_token_indices]

        count_mode = 1  # `count` mode
        return permuted_hidden_states, expert_tokens, None, count_mode

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Weight the MLP rows and accumulate them back per source token.

        The result overwrites ``hidden_states`` in place.
        """
        weighted = mlp_output * self.sorted_weights.unsqueeze(1)
        accumulated = torch.zeros_like(hidden_states)
        accumulated.index_add_(0, self.sorted_token_indices, weighted)
        hidden_states[:] = accumulated
class MC2CommImpl(MoECommMethod):
    """MoE communication through the MC2 dispatch/combine ops.

    Applicable when:
    1. `enable_expert_parallel=True`.
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are
       available.
    3. `enable_expert_parallel=False` is not supported.

    MC2 is optimized for Communication and Computation parallelism on
    Ascend devices.
    """

    def __init__(self, moe_config: Optional[FusedMoEConfig]):
        super().__init__(moe_config)

        # NOTE: mc2_group's own rank and world size are not needed because
        # ep_group and mc2_group basically have the same init params.
        # A separate group exists only because of the MC2 restriction:
        # "No other groups can be used in the same process as the MC2 group."
        backend = get_mc2_group().device_group._get_backend(
            torch.device("npu"))
        self.mc2_comm_name = backend.get_hccl_comm_name(
            self.moe_config.ep_rank)

        # Feature flags
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        self.need_extra_args = self.is_ascend_a3
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        """Record the tensor-parallel world size and rank on ``self``."""
        # NOTE: Since vLLM flatten tp across dp, we need to restore the
        # original tp_size and tp_rank.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad to the forward context's padded length, then split by TP.

        The target length comes from ``forward_context.padded_num_tokens``.
        When TP size > 1 the padded tensors and the MC2 mask are split along
        dim 0 and this rank keeps its own slice.
        """
        self.num_tokens, _ = hidden_states.shape
        ctx = get_forward_context()
        self.mc2_mask = ctx.mc2_mask

        pad_size = ctx.padded_num_tokens - self.num_tokens
        if pad_size > 0:
            row_padding = (0, 0, 0, pad_size)
            hidden_states = nn.functional.pad(hidden_states, row_padding)
            router_logits = nn.functional.pad(router_logits, row_padding)

        if self.tp_size > 1:
            hidden_parts = torch.tensor_split(hidden_states,
                                              self.tp_size,
                                              dim=0)
            logits_parts = torch.tensor_split(router_logits,
                                              self.tp_size,
                                              dim=0)
            mask_parts = torch.tensor_split(self.mc2_mask,
                                            self.tp_size,
                                            dim=0)
            # Kept so finalize() can all-gather back into these slices.
            self.split_hidden_states = hidden_parts

            hidden_states = hidden_parts[self.tp_rank]
            router_logits = logits_parts[self.tp_rank]
            self.mc2_mask = mask_parts[self.tp_rank]

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """All-gather across TP when TP size > 1, then strip the padding."""
        if self.tp_size > 1:
            gathered = list(self.split_hidden_states)
            dist.all_gather(gathered, hidden_states,
                            self.moe_config.tp_group.device_group)
            hidden_states = torch.cat(self.split_hidden_states, dim=0)

        if self.num_tokens < hidden_states.shape[0]:
            hidden_states = hidden_states[:self.num_tokens]

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Dispatch tokens with the MC2 dispatch op.

        Stores the routing tensors and the op's bookkeeping outputs on
        ``self`` for unpermute(). Returns the dispatched hidden states, the
        expert token tensor, the dynamic scale and group-list type 1.
        """
        # Kept for the combine step in unpermute().
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights.to(torch.float32)

        cfg = self.moe_config
        dispatch_kwargs = {
            "x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": cfg.num_experts,
            "global_bs": 0,
            "scales": None,
            "quant_mode": 2 if apply_a8_quantization else 0,
            "group_ep": self.mc2_comm_name,
            "ep_world_size": cfg.ep_size,
            "ep_rank_id": cfg.ep_rank,
        }
        if self.need_extra_args:
            dispatch_kwargs["group_tp"] = self.mc2_comm_name
            dispatch_kwargs["tp_world_size"] = 1
            dispatch_kwargs["tp_rank_id"] = 0
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            dispatch_kwargs["x_active_mask"] = self.mc2_mask

        if self.enable_dispatch_v2:
            dispatch = torch_npu.npu_moe_distribute_dispatch_v2
        else:
            dispatch = torch_npu.npu_moe_distribute_dispatch

        # Only the first six outputs of the op are consumed.
        outputs = dispatch(**dispatch_kwargs)[:6]
        (
            permuted_hidden_states,
            dynamic_scale,
            self.assist_info_for_combine,
            expert_tokens,
            self.ep_recv_counts,
            self.tp_recv_counts,
        ) = outputs

        group_list_type = 1
        return (permuted_hidden_states, expert_tokens, dynamic_scale,
                group_list_type)

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Combine expert outputs with the MC2 combine op.

        The result overwrites ``hidden_states`` in place.
        """
        cfg = self.moe_config
        combine_kwargs = {
            "expand_x": mlp_output,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": cfg.num_experts,
            "global_bs": 0,
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.mc2_comm_name,
            "ep_world_size": cfg.ep_size,
            "ep_rank_id": cfg.ep_rank,
        }

        # The v2 and v1 ops take the dispatch bookkeeping under different
        # keyword names.
        assist_key = ("assist_info_for_combine"
                      if self.enable_dispatch_v2 else "expand_idx")
        combine_kwargs[assist_key] = self.assist_info_for_combine

        if self.need_extra_args:
            combine_kwargs["tp_send_counts"] = self.tp_recv_counts
            combine_kwargs["group_tp"] = self.mc2_comm_name
            combine_kwargs["tp_world_size"] = 1
            combine_kwargs["tp_rank_id"] = 0
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            combine_kwargs["x_active_mask"] = self.mc2_mask

        if self.enable_dispatch_v2:
            combine = torch_npu.npu_moe_distribute_combine_v2
        else:
            combine = torch_npu.npu_moe_distribute_combine

        hidden_states[:] = combine(**combine_kwargs)
class AlltoAllCommImpl(MoECommMethod):
    """MoE communication based on all-to-all token exchange.

    Applicable when:
    1. `enable_expert_parallel=True`.
    2. `npu_grouped_matmul` is available.

    Tokens are exchanged between data parallel ranks with all-to-all before
    and after the MLP computation. It should perform better than
    AllGatherCommImpl when DP size > 1.
    """

    def __init__(self, moe_config: Optional[FusedMoEConfig]):
        super().__init__(moe_config)
        # Imported here rather than at module level, as in the original.
        from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
            get_token_dispatcher

        dispatcher_name = "TokenDispatcherWithAll2AllV"
        self.token_dispatcher = get_token_dispatcher(dispatcher_name)
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        """Record the tensor-parallel world size and rank on ``self``."""
        # NOTE: Since vLLM flatten tp across dp, we need to restore the
        # original tp_size and tp_rank.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad up to ``tp_size`` rows if needed, then keep this rank's split.

        Padding is added only when there are fewer tokens than TP ranks.
        When TP size > 1 both tensors are split along dim 0 and this rank's
        slice is returned.
        """
        self.num_tokens, _ = hidden_states.shape

        pad_size = self.tp_size - self.num_tokens
        if pad_size > 0:
            row_padding = (0, 0, 0, pad_size)
            hidden_states = nn.functional.pad(hidden_states, row_padding)
            router_logits = nn.functional.pad(router_logits, row_padding)

        if self.tp_size > 1:
            hidden_parts = torch.tensor_split(hidden_states,
                                              self.tp_size,
                                              dim=0)
            logits_parts = torch.tensor_split(router_logits,
                                              self.tp_size,
                                              dim=0)
            # Kept so finalize() can all-gather back into these slices.
            self.split_hidden_states = hidden_parts

            hidden_states = hidden_parts[self.tp_rank]
            router_logits = logits_parts[self.tp_rank]

        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """All-gather across TP when TP size > 1, then strip the padding."""
        if self.tp_size > 1:
            gathered = list(self.split_hidden_states)
            dist.all_gather(gathered, hidden_states,
                            self.moe_config.tp_group.device_group)
            hidden_states = torch.cat(self.split_hidden_states, dim=0)

        if self.num_tokens < hidden_states.shape[0]:
            hidden_states = hidden_states[:self.num_tokens]

        return hidden_states

    def permute(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
        apply_a8_quantization: bool,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
        """Delegate token dispatch to the configured token dispatcher.

        Returns the ``hidden_states``, ``group_list``, ``dynamic_scale`` and
        ``group_list_type`` entries of the dispatcher's result.
        """
        dispatched = self.token_dispatcher.token_dispatch(
            hidden_states,
            topk_weights,
            topk_ids,
            None,
            log2phy=None,
            with_quant=apply_a8_quantization)

        permuted_hidden_states = dispatched["hidden_states"]
        group_list = dispatched["group_list"]
        dynamic_scale = dispatched["dynamic_scale"]
        group_list_type = dispatched["group_list_type"]
        return (permuted_hidden_states, group_list, dynamic_scale,
                group_list_type)

    def unpermute(self, mlp_output: torch.Tensor,
                  hidden_states: torch.Tensor) -> None:
        """Combine the MLP output via the dispatcher, in place."""
        combined = self.token_dispatcher.token_combine(mlp_output)
        hidden_states[:] = combined

View File

@@ -0,0 +1,447 @@
import array
import hashlib
import json
import os
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import cdiv, logger
from vllm.v1.core.sched.output import NewRequestData
@dataclass
class MooncakeEngineMetadata:
    """Model identity, worker placement and KV tensor layout.

    The per-field notes are ordinary comments rather than bare string
    literals: a leading string literal would otherwise be picked up as the
    class docstring and describe only the first field.
    """

    # Name of the LLM model.
    model_name: str
    # World size when running under a distributed setting.
    world_size: int
    # Worker id when running under a distributed setting.
    worker_id: int
    # The format (dtype) of the KV tensors.
    kv_dtype: torch.dtype
    # The shape of the KV tensors:
    # (num_layer, 2, metadata.block_size, num_kv_head, head_size)
    kv_shape: tuple[int, int, int, int, int]
    # Number of tokens per block; defaults to 128.
    block_size: int = 128
    # Whether MLA is used.
    use_mla: bool = False
@dataclass(order=True)
class MooncakeEngineKey:
    """Orderable, hashable key identifying one cached chunk."""

    model_name: str
    world_size: int
    worker_id: int
    chunk_hash: str

    def __hash__(self):
        """Hash over all four fields."""
        identity = (
            self.model_name,
            self.world_size,
            self.worker_id,
            self.chunk_hash,
        )
        return hash(identity)

    def to_string(self):
        """Render the key as ``model@world_size@worker_id@chunk_hash``."""
        prefix = f"{self.model_name}@{self.world_size}"
        suffix = f"@{self.worker_id}@{self.chunk_hash}"
        return prefix + suffix

    def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
        """Return one per-layer key for each layer id in ``range(num_layers)``."""
        return [
            LayerMooncakeEngineKey(
                self.model_name,
                self.world_size,
                self.worker_id,
                self.chunk_hash,
                layer_id,
            ) for layer_id in range(num_layers)
        ]

    def to_dict(self):
        """Return a plain-dict form of the key."""
        # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack.
        payload = {"__type__": "CacheEngineKey"}
        payload["model_name"] = self.model_name
        payload["world_size"] = self.world_size
        payload["worker_id"] = self.worker_id
        payload["chunk_hash"] = self.chunk_hash
        return payload

    @staticmethod
    def from_dict(d):
        """Rebuild a key from a mapping holding the four field names."""
        model_name = d["model_name"]
        world_size = d["world_size"]
        worker_id = d["worker_id"]
        chunk_hash = d["chunk_hash"]
        return MooncakeEngineKey(model_name, world_size, worker_id,
                                 chunk_hash)
@dataclass(order=True)
class LayerMooncakeEngineKey(MooncakeEngineKey):
    """A key for the layer cache engine: the base key plus a layer id."""

    layer_id: int

    def __hash__(self):
        """Hash over the four base fields and ``layer_id``."""
        identity = (
            self.model_name,
            self.world_size,
            self.worker_id,
            self.chunk_hash,
            self.layer_id,
        )
        return hash(identity)

    def to_string(self):
        """Render as ``model@world_size@worker_id@chunk_hash@layer_id``."""
        prefix = f"{self.model_name}@{self.world_size}"
        suffix = f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}"
        return prefix + suffix
class ChunkedTokenDatabase():
    """Split token sequences into fixed-size chunks and derive cache keys.

    Each chunk's hash is chained on the previous chunk's hash, so a key
    identifies the whole prefix ending at that chunk.
    """

    def __init__(
        self,
        metadata: "MooncakeEngineMetadata",
    ):
        self.metadata = metadata

    def _make_key_by_hash(self,
                          chunk_hash: str,
                          layer_id: Optional[int] = None):
        """Build a ``MooncakeEngineKey`` for ``chunk_hash``.

        ``layer_id`` is accepted but not used by this method.
        """
        assert self.metadata is not None
        return MooncakeEngineKey(
            self.metadata.model_name,
            self.metadata.world_size,
            self.metadata.worker_id,
            chunk_hash,
        )

    def _hash(
        self,
        tokens: Union[torch.Tensor, List[int]],
        prefix_hash: str,
    ) -> str:
        """Return the SHA-256 hex digest of ``prefix_hash`` plus the tokens.

        Raises:
            TypeError: if ``tokens`` is neither a ``torch.Tensor`` nor a
                ``list``. Previously this case fell through with the byte
                buffer unbound and surfaced as an ``UnboundLocalError``.
        """
        # TODO: change it to a more efficient hash function
        if isinstance(tokens, torch.Tensor):
            tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
        elif isinstance(tokens, list):
            tokens_bytes = array.array("I", tokens).tobytes()
        else:
            raise TypeError(
                f"Unsupported tokens type {type(tokens)}; "
                "expected torch.Tensor or list")
        return hashlib.sha256(prefix_hash.encode("ascii") +
                              tokens_bytes).hexdigest()

    def _chunk_tokens(
        self,
        tokens: Union[torch.Tensor, List[int]],
    ) -> Iterable[Union[torch.Tensor, List[int]]]:
        """Yield consecutive slices of ``self.metadata.block_size`` tokens.

        :param tokens: the input tokens, with shape [seq_len]

        :return: a generator of chunks of tokens; the last chunk may be
            shorter than ``metadata.block_size``
        """
        step = self.metadata.block_size
        for begin in range(0, len(tokens), step):
            yield tokens[begin:begin + step]

    def _prefix_hash(
        self,
        token_chunks: Iterable[Union[torch.Tensor, List[int]]],
    ) -> Iterable[str]:
        """Yield one chained hash per chunk, starting from an empty prefix."""
        prefix_hash = ''
        for token_chunk in token_chunks:
            prefix_hash = self._hash(token_chunk, prefix_hash)
            yield prefix_hash

    def process_tokens(
        self,
        tokens: Union[torch.Tensor, List[int]],
        mask: Optional[torch.Tensor] = None,
    ) -> Iterable[Tuple[int, int, "MooncakeEngineKey"]]:
        """Process the tokens and yield the corresponding cache engine keys.

        This is a generator: nothing, including the mask validation, runs
        until the result is iterated.

        :param Union[torch.Tensor, List[int]] tokens: The tokens to process.

        :param Optional[torch.Tensor] mask: The mask for the tokens. Should
            have the same length as tokens. And the mask should ALWAYS be like
            FFFFFTTTTTTT, where True means the tokens needs to be matched,
            and the Falses will ALWAYS be at the PREFIX of the tensor.

        :returns: An iterable of tuples with three elements. The first element
            is the start index of the tokens for the key. The second element
            is the end index of the tokens for the key. The third element is
            the cache engine key for the tokens.

        :raises: ValueError if the number of Falses in the mask is not a
            multiple of the chunk size.
        """
        block_size = self.metadata.block_size
        if mask is not None:
            num_falses = mask.numel() - mask.long().sum().item()
        else:
            num_falses = 0

        if num_falses % block_size != 0:
            raise ValueError(
                "The number of Falses in the mask is not a multiple of the chunk size."
            )

        total_len = len(tokens)
        prefix_hashes = self._prefix_hash(self._chunk_tokens(tokens))
        for chunk_id, hash_val in enumerate(prefix_hashes):
            start_idx = chunk_id * block_size
            # Masked-out prefix chunks are still hashed (the chain needs
            # them) but no key is emitted for them.
            if start_idx < num_falses:
                continue
            end_idx = min(start_idx + block_size, total_len)
            yield start_idx, end_idx, self._make_key_by_hash(hash_val)
@dataclass
class LoadSpec:
    """Describes how much of a request's KV cache can be loaded.

    Attributes:
        vllm_cached_tokens: Number of tokens cached in vLLM.
        mooncake_cached_tokens: Number of tokens that are cached in mooncake.
        can_load: Whether the scheduler allows the tokens to be loaded.
    """

    vllm_cached_tokens: int
    mooncake_cached_tokens: int
    can_load: bool
@dataclass
class SaveSpec:
    """Describes which part of a request's tokens should be saved.

    Attributes:
        skip_leading_tokens: Number of leading tokens to skip because they
            were already saved.
        can_save: Whether the scheduler allows the tokens to be saved.
    """

    skip_leading_tokens: int
    can_save: bool
@dataclass
class RequestTracker:
    """Accumulates the scheduled tokens and allocated blocks of a request."""

    # Request id
    req_id: str

    # The token ids that have been scheduled so far
    token_ids: list[int]

    # The block ids that have been allocated so far
    # NOTE: allocated blocks could be more than the number of tokens
    # FIXME: need to check whether the block ids will be changed after
    #        preemption
    allocated_block_ids: list[int]

    # The number of tokens that have been saved
    num_saved_tokens: int = 0

    @staticmethod
    def from_new_request(
        new_request: "NewRequestData",
        num_tokens_to_compute: int,
    ) -> "RequestTracker":
        """Create the request tracker from a new request.

        Args:
            new_request (NewRequestData): the new request data.
            num_tokens_to_compute (int): the number of tokens that will
                be 'computed', including the `num_computed_tokens` (vLLM's
                local cache hit) and new tokens that will be scheduled.
        """
        # vLLM 0.9.0 update: request.block_ids changed from list[int] to
        # list[list[int]], so both layouts are accepted here.
        block_ids = new_request.block_ids
        if not block_ids:
            # Nothing allocated yet; indexing [0] would raise IndexError.
            unfolded_block_ids = []
        elif isinstance(block_ids[0], list):
            unfolded_block_ids = list(block_ids[0])
        else:
            # list(...) also copies a flat tuple, which has no .copy().
            unfolded_block_ids = list(block_ids)

        return RequestTracker(
            req_id=new_request.req_id,
            token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
            copy(),
            allocated_block_ids=unfolded_block_ids,
            num_saved_tokens=0,
        )

    def update(
        self,
        new_token_ids: list[int],
        new_block_ids: Union[tuple[list[int], ...], list[int]],
    ) -> None:
        """Update the request tracker when a running request is
        scheduled again.

        Raises:
            ValueError: if ``new_block_ids`` is non-empty and is neither a
                tuple nor a list. The tracker is left unmodified in that
                case.
        """
        # Normalise the block ids before touching any state, so a rejected
        # argument cannot leave token_ids extended without matching blocks.
        if len(new_block_ids) == 0:
            new_block_ids = []
        elif isinstance(new_block_ids, tuple):
            new_block_ids = new_block_ids[0]
        elif isinstance(new_block_ids, list):
            pass
        else:
            raise ValueError(
                f"Unsupported new_block_ids type {type(new_block_ids)}")

        self.token_ids.extend(new_token_ids)
        self.allocated_block_ids.extend(new_block_ids)
@dataclass
class ReqMeta:
    """Per-request load/save instructions derived from a RequestTracker."""

    # Request id
    req_id: str
    # Request tokens
    token_ids: torch.Tensor

    block_ids: list[int]
    # # Slot mapping if exchange for block_id
    # slot_mapping: torch.Tensor
    # Skip save or not
    save_spec: Optional[SaveSpec] = None
    # load_spec
    load_spec: Optional[LoadSpec] = None

    is_last_chunk: Optional[bool] = None

    @staticmethod
    def from_request_tracker(
        tracker: RequestTracker,
        block_size: int,
        load_spec: Optional[LoadSpec] = None,
        skip_save: Optional[bool] = False,
        is_last_chunk: Optional[bool] = None,
        discard_partial_chunks: bool = True,
    ) -> Optional["ReqMeta"]:
        """Build the metadata for one request, or ``None`` if nothing to do.

        Args:
            tracker (RequestTracker): the request tracker; its
                ``num_saved_tokens`` is updated when a save is scheduled.
            block_size (int): the block size in vLLM.
            load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
            skip_save (bool): whether to skip the save operation.
            is_last_chunk (Optional[bool]): stored on the result unchanged.
            discard_partial_chunks (bool): whether to discard partial chunks.

        Returns:
            the request metadata if we need to perform load/save
            operations, None otherwise.
        """
        all_token_ids = tracker.token_ids
        total_tokens = len(all_token_ids)
        already_saved = tracker.num_saved_tokens

        # Saving is skipped when the unsaved tokens have not yet reached the
        # next chunk boundary. Without partial-chunk discarding the boundary
        # is 0 and every token is eligible.
        if discard_partial_chunks:
            chunk_boundary = cdiv(already_saved + 1, block_size) * block_size
            num_tokens_to_save = total_tokens // block_size * block_size
        else:
            chunk_boundary = 0
            num_tokens_to_save = total_tokens

        should_skip = skip_save or num_tokens_to_save < chunk_boundary
        if should_skip and load_spec is None:
            return None

        # A scheduled save advances the tracker's saved-token counter.
        if not should_skip:
            tracker.num_saved_tokens = num_tokens_to_save
        save_spec = SaveSpec(already_saved, not should_skip)

        # Token ids covered by this load/save.
        token_ids = torch.tensor(all_token_ids)[:num_tokens_to_save]

        # Keep the load spec only when the request is scheduled to load.
        if load_spec is None or not load_spec.can_load:
            # Do not load if not in `can_load` state
            load_spec = None
        else:
            logger.debug(
                "Scheduled to load %d tokens for request %s",
                load_spec.mooncake_cached_tokens,
                tracker.req_id,
            )
        logger.debug(
            f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}"
        )

        return ReqMeta(
            req_id=tracker.req_id,
            token_ids=token_ids,
            block_ids=tracker.allocated_block_ids,
            save_spec=save_spec,
            load_spec=load_spec,
            is_last_chunk=is_last_chunk,
        )
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Connector metadata: a list of request metas plus unfinished ids."""

    def __init__(self, unfinished_request_ids):
        # Stored as given; the request list starts empty and is filled via
        # add_request().
        self.unfinished_request_ids = unfinished_request_ids
        self.requests = []

    def add_request(self, req_meta: ReqMeta) -> None:
        """Append one request's metadata.

        Args:
            req_meta (ReqMeta): the request metadata.
        """
        pending = self.requests
        pending.append(req_meta)
@dataclass
class LasyerMultiBlockReqMeta:
    """Request metadata for one layer, spanning several blocks.

    Attributes:
        req_id: Request id.
        keys: Per-layer cache keys.
        starts: Start indices.
        ends: End indices.
        block_ids: Block ids.
        layer_id: Layer id.
    """

    req_id: str
    keys: List["LayerMooncakeEngineKey"]
    starts: List[int]
    ends: list[int]
    block_ids: list[int]
    layer_id: int
@dataclass
class MooncakeStoreConfig:
    """Settings for the Mooncake store, loaded from a JSON file."""

    local_hostname: str
    metadata_server: str
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        """Load the configuration from the JSON file at ``file_path``.

        Keys missing from the file fall back to the defaults below;
        ``local_hostname``, ``metadata_server`` and ``master_server_address``
        have no default and become ``None`` when absent.
        """
        # JSON is read as UTF-8 explicitly so decoding does not depend on
        # the process locale.
        with open(file_path, encoding="utf-8") as file:
            config = json.load(file)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=config.get("global_segment_size", 3355443200),
            local_buffer_size=config.get("local_buffer_size", 1073741824),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"))

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load the configuration named by ``MOONCAKE_CONFIG_PATH``.

        Raises:
            ValueError: if the environment variable is unset or empty.
        """
        config_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if not config_path:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeStoreConfig.from_file(config_path)

View File

@@ -0,0 +1,251 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
import torch
from vllm.utils import logger
from vllm_ascend.distributed.mooncake.config_data import (
ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
class KVTransferThread(threading.Thread):
    """Daemon thread that serves queued KV-cache transfer requests.

    Requests are enqueued with :meth:`add_request` and consumed one at a
    time by :meth:`run`, which passes each one to :meth:`_handle_request`.
    The base ``_handle_request`` is a no-op; subclasses implement the
    actual transfer.
    """

    def __init__(self, tp_rank: int, tp_size: int, m_store: "Mooncakestore",
                 local_kv_caches_base_addr: list[int],
                 token_database: "ChunkedTokenDatabase", block_len: list[int],
                 block_size: int, ready_event: threading.Event, name: str):
        """Store the transfer context; the thread is not started here.

        Args:
            tp_rank: stored as ``self.tp_rank``.
            tp_size: stored as ``self.tp_size``.
            m_store: store object used by subclasses for put/get.
            local_kv_caches_base_addr: base address of every registered
                cache tensor.
            token_database: splits token sequences into keyed chunks.
            block_len: byte length of one block; two entries select the
                MLA layout.
            block_size: number of tokens per block.
            ready_event: set by :meth:`run` once the thread is running.
            name: thread name.
        """
        super().__init__(daemon=True, name=name)
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.m_store = m_store
        self.ready_event = ready_event
        self.kv_caches_base_addr = local_kv_caches_base_addr
        self.block_len = block_len
        self.token_database = token_database
        self.block_size = block_size
        # Guards self.finished_requests.
        self.done_task_lock = threading.Lock()
        # TODO(jianzs): find a better way to detect MLA.
        self.use_mla = len(block_len) == 2

        self.request_queue: queue.Queue[Any] = queue.Queue()
        # TODO(jianzs): make this configurable
        self.executor = ThreadPoolExecutor(max_workers=32)
        self.finished_requests: set[str] = set()

    def prepare_value(self, start: int, end: int, block_ids: list[int]):
        """Compute address and size of a token chunk in every cache tensor.

        Args:
            start: first token offset of the chunk.
            end: token offset one past the chunk.
            block_ids: block ids of the request; the chunk's block is
                ``block_ids[start // block_size]``.

        Returns:
            ``(addr_list, size_list, block_id)`` with one address and one
            byte size per entry of ``kv_caches_base_addr``.
        """
        addr_list = []
        size_list = []
        block_id = block_ids[start // self.block_size]
        for index, base_addr in enumerate(self.kv_caches_base_addr):
            # With MLA the two block lengths alternate across the
            # registered tensors; otherwise all tensors share one length.
            block_len = (self.block_len[index % 2]
                         if self.use_mla else self.block_len[0])

            addr = base_addr + block_id * block_len
            # Bytes per token times the number of tokens in the chunk.
            length = int(block_len / self.block_size * (end - start))
            addr_list.append(addr)
            size_list.append(length)
        return addr_list, size_list, block_id

    def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
                            layer_id: int):
        """Compute address and size of a token chunk for a single layer.

        The two tensors of layer ``layer_id`` sit at indices
        ``2 * layer_id`` and ``2 * layer_id + 1`` of
        ``kv_caches_base_addr``.

        Returns:
            ``(addr_list, size_list)``, each with two entries.
        """
        block_id = block_ids[start // self.block_size]
        if self.use_mla:
            addr_k = self.kv_caches_base_addr[layer_id *
                                              2] + block_id * self.block_len[0]
            addr_v = self.kv_caches_base_addr[layer_id * 2 +
                                              1] + block_id * self.block_len[1]
            length_k = int(self.block_len[0] / self.block_size * (end - start))
            length_v = int(self.block_len[1] / self.block_size * (end - start))
            size_list = [length_k, length_v]
        else:
            addr_k = self.kv_caches_base_addr[layer_id *
                                              2] + block_id * self.block_len[0]
            addr_v = self.kv_caches_base_addr[layer_id * 2 +
                                              1] + block_id * self.block_len[0]
            length = int(self.block_len[0] / self.block_size * (end - start))
            size_list = [length, length]
        addr_list = [addr_k, addr_v]
        return addr_list, size_list

    def add_request(
        self,
        req_id: str,
        tokens: torch.Tensor,
        block_ids: list[int],
        mask: Optional[torch.Tensor] = None,
        is_last_chunk: Optional[bool] = None,
    ) -> None:
        """Queue a transfer request for the worker loop; returns nothing."""
        req = ({
            "req_id": req_id,
            "tokens": tokens,
            "block_ids": block_ids,
            "mask": mask,
            "is_last_chunk": is_last_chunk,
        })
        self.request_queue.put(req)

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.
        Returns:
            A set of request IDs that have been completed.
        """
        with self.done_task_lock:
            finished_requests = self.finished_requests.copy()
            self.finished_requests.clear()
        return finished_requests

    def set_finished_request(self, req_id):
        """Record ``req_id`` as completed."""
        with self.done_task_lock:
            self.finished_requests.add(req_id)

    def run(self):
        """Run the thread to handle KV cache transfer requests."""
        self.ready_event.set()
        while True:
            try:
                request_data = self.request_queue.get()
                if request_data is None:
                    logger.warning("Received a None request!")
                    self.request_queue.task_done()
                    continue
                self._handle_request(request_data)
            except Exception as e:
                # Keep the loop alive, but record the traceback so the
                # failing transfer can be diagnosed.
                logger.error(f"Error in KVCacheTransferThread: {e}",
                             exc_info=True)

    def _handle_request(self, req_meta: dict[str, Any]):
        """Process one dequeued request; overridden by subclasses."""
        pass
class KVCacheStoreSendingThread(KVTransferThread):
    """Worker thread that writes the KV cache of whole requests to the store."""

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheSendingThread")

    def _handle_request(self, req_meta: dict[str, Any]):
        """Put every chunk of the request into the store.

        The request is reported as finished only when it carries
        ``is_last_chunk``. The queue item is always acknowledged, even when
        the transfer raises, so the queue's task accounting stays balanced.
        """
        try:
            tokens = req_meta["tokens"]
            mask = req_meta["mask"]
            block_ids = req_meta["block_ids"]
            req_id = req_meta["req_id"]
            is_last_chunk = req_meta["is_last_chunk"]
            # Wait for work on the current NPU stream before reading the
            # cache memory (presumably so the KV data is fully written;
            # confirm).
            torch.npu.current_stream().synchronize()
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, _ = self.prepare_value(start, end, block_ids)
                self.m_store.put(key, addr, size)
            if is_last_chunk:
                self.set_finished_request(req_id)
        finally:
            self.request_queue.task_done()
class KVCacheStoreRecvingThread(KVTransferThread):
    """Worker thread that reads the KV cache of whole requests from the store."""

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreRecvingThread")

    def _handle_request(self, req_meta: dict[str, Any]):
        """Fetch every chunk of the request from the store.

        The request is reported as finished once all chunks were issued.
        The queue item is always acknowledged, even when the transfer
        raises, so the queue's task accounting stays balanced.
        """
        try:
            tokens = req_meta["tokens"]
            mask = req_meta["mask"]
            block_ids = req_meta["block_ids"]
            req_id = req_meta["req_id"]
            for start, end, key in self.token_database.process_tokens(
                    tokens, mask):
                addr, size, _ = self.prepare_value(start, end, block_ids)
                self.m_store.get(key, addr, size)
            self.set_finished_request(req_id)
        finally:
            self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread):
    """Worker thread that writes the KV cache to the store one layer at a time."""

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event,
                 num_layers: int):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreLayerSendingThread")
        # A request is reported finished after this layer was sent.
        self.final_layer_id = num_layers - 1

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        """Queue a per-layer request object; returns nothing."""
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        """Put all chunks of one layer into the store.

        The queue item is always acknowledged, even when the transfer
        raises, so the queue's task accounting stays balanced.
        """
        try:
            # Wait for work on the current NPU stream before reading the
            # cache memory (presumably so the KV data is fully written;
            # confirm).
            torch.npu.current_stream().synchronize()
            for index, key in enumerate(req_meta.keys):
                addr, size = self.prepare_value_layer(req_meta.starts[index],
                                                      req_meta.ends[index],
                                                      req_meta.block_ids,
                                                      req_meta.layer_id)
                self.m_store.put(key, addr, size)
            if req_meta.layer_id == self.final_layer_id:
                self.set_finished_request(req_meta.req_id)
        finally:
            self.request_queue.task_done()
class KVCacheStoreLayerRecvingThread(KVTransferThread):
    """Worker thread that reads the KV cache from the store one layer at a time."""

    def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
                 local_kv_caches_base_addr: list[int],
                 token_database: ChunkedTokenDatabase, block_len: list[int],
                 block_size: int, ready_event: threading.Event,
                 get_event: threading.Event):
        super().__init__(tp_rank,
                         tp_size,
                         m_store,
                         local_kv_caches_base_addr,
                         token_database,
                         block_len,
                         block_size,
                         ready_event,
                         name="KVCacheStoreLayerRecvingThread")
        # Set after each layer request has been processed.
        self.get_event = get_event

    def add_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
        """Queue a per-layer request object; returns nothing."""
        self.request_queue.put(req_meta)

    def _handle_request(  # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta):
        """Fetch all chunks of one layer from the store.

        The queue item is always acknowledged, even when the transfer
        raises. ``get_event`` is set only after a successful pass, so a
        failed layer is still seen as a timeout by whoever waits on it.
        """
        try:
            for index, key in enumerate(req_meta.keys):
                addr, size = self.prepare_value_layer(req_meta.starts[index],
                                                      req_meta.ends[index],
                                                      req_meta.block_ids,
                                                      req_meta.layer_id)
                self.m_store.get(key, addr, size)
        finally:
            self.request_queue.task_done()
        self.get_event.set()

View File

@@ -0,0 +1,489 @@
# Standard
import math
import threading
import time
from typing import Generator, List, Optional, Union
# Third Party
import torch
from vllm.config import VllmConfig
from vllm.utils import get_kv_cache_torch_dtype, logger
from vllm_ascend.distributed.mooncake.config_data import (
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
MooncakeEngineMetadata)
from vllm_ascend.distributed.mooncake.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
class MooncakeEngine:
#The main class for the cache engine.
def __init__(
    self,
    vllm_config: VllmConfig,
    use_layerwize: bool,
):
    """Derive the engine state from the vLLM configuration.

    :param vllm_config: source of the model, parallel, cache and
        KV-transfer settings read below.
    :param use_layerwize: stored as ``self.use_layerwise``; selects the
        layer-by-layer transfer path.
    """
    parallel_config = vllm_config.parallel_config
    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config

    # MLA is on only when the model config exposes a genuine bool flag
    # that is set.
    mla_flag = getattr(model_config, "use_mla", None)
    self.use_mla = isinstance(mla_flag, bool) and mla_flag

    self.use_layerwise = use_layerwize
    self.tp_rank = parallel_config.rank
    self.tp_size = parallel_config.tensor_parallel_size
    self.kv_role = vllm_config.kv_transfer_config.kv_role
    self.block_size = cache_config.block_size
    self.current_layer = 0
    self.num_layers = model_config.get_num_layers(parallel_config)

    num_kv_head = model_config.get_num_kv_heads(parallel_config)
    head_size = model_config.get_head_size()
    kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype,
                                        model_config.dtype)
    self.hidden_dim_size = num_kv_head * head_size

    # Shape handed to the engine metadata; the MLA variant uses a single
    # cache with one head.
    kv_shape = ((self.num_layers, 1, self.block_size, 1, head_size)
                if self.use_mla else
                (self.num_layers, 2, self.block_size, num_kv_head,
                 head_size))
    self.metadata = MooncakeEngineMetadata(
        model_config.model,
        parallel_config.world_size,
        parallel_config.rank,
        kv_dtype,
        kv_shape,
        self.block_size,
        self.use_mla,
    )

    self.token_database = ChunkedTokenDatabase(self.metadata)
    self.m_store = Mooncakestore(parallel_config)

    # Created later by register_kv_caches().
    self.kv_send_thread: Optional[KVTransferThread] = None
    self.kv_recv_thread: Optional[KVTransferThread] = None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Record the KV-cache layout and start the transfer threads.

    Computes the byte length of one block (two lengths for MLA), collects
    the base address of every cache tensor, hands the caches to the store
    and starts the sending thread (producer / both roles only) and the
    receiving thread (always).

    :param kv_caches: mapping whose values are iterated as the cache
        tensors of one layer (two tensors per layer are indexed below).
    """
    _, first_kv_cache_tuple = next(iter(kv_caches.items()))
    first_kv_cache = first_kv_cache_tuple[0]
    # The leading dimension is the number of blocks in both layouts.
    self.num_blocks = first_kv_cache.shape[0]
    # The last three dimensions make up one block.
    block_rank = 3

    # TODO(tms): Find a more robust way to detect and handle MLA
    if self.use_mla:
        # MLA case.[num_block, block_size, 1, hidden_dim]
        norm_cache = first_kv_cache_tuple[0]
        pe_cache = first_kv_cache_tuple[1]
        block_shape_norm = norm_cache.shape[-block_rank:]
        block_shape_pe = pe_cache.shape[-block_rank:]
        # Each element size comes from its own tensor. Indexing
        # first_kv_cache[1] would address block 1 of the first tensor and
        # fail when only one block exists.
        self.block_len = [
            norm_cache.element_size() * math.prod(block_shape_norm),
            pe_cache.element_size() * math.prod(block_shape_pe),
        ]
        logger.info(
            "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
            self.num_blocks, block_shape_norm, block_shape_pe)
    else:
        # [num_block, block_size, num_head, hidden_dim]
        kv_elem_size = first_kv_cache.element_size()
        block_shape = first_kv_cache.shape[-block_rank:]
        self.block_len = [kv_elem_size * math.prod(block_shape)]
        logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                    block_shape)

    logger.info("Registering KV_Caches. use_mla: %s, shape %s",
                self.use_mla, first_kv_cache.shape)

    self.kv_caches = kv_caches
    self.m_store.set_kv_caches(kv_caches.values())
    # Flatten to one base address per cache tensor, layer by layer.
    self.kv_caches_base_addr = [
        cache.data_ptr() for layer_caches in kv_caches.values()
        for cache in layer_caches
    ]

    if self.use_layerwise:
        self.get_event = threading.Event()
        if self.kv_role in ['kv_producer', 'kv_both']:
            ready_event_sending = threading.Event()
            self.kv_send_thread = KVCacheStoreLayerSendingThread(
                self.tp_rank, self.tp_size, self.m_store,
                self.kv_caches_base_addr, self.token_database,
                self.block_len, self.block_size, ready_event_sending,
                self.num_layers)
            self.kv_send_thread.start()
        ready_event = threading.Event()
        self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
            self.tp_rank, self.tp_size, self.m_store,
            self.kv_caches_base_addr, self.token_database, self.block_len,
            self.block_size, ready_event, self.get_event)
        self.kv_recv_thread.start()
        # Only the receiving thread is awaited before returning.
        ready_event.wait()
    else:
        if self.kv_role in ['kv_producer', 'kv_both']:
            ready_event_sending = threading.Event()
            self.kv_send_thread = KVCacheStoreSendingThread(
                self.tp_rank, self.tp_size, self.m_store,
                self.kv_caches_base_addr, self.token_database,
                self.block_len, self.block_size, ready_event_sending)
            self.kv_send_thread.start()
        ready_event = threading.Event()
        self.kv_recv_thread = KVCacheStoreRecvingThread(
            self.tp_rank, self.tp_size, self.m_store,
            self.kv_caches_base_addr, self.token_database, self.block_len,
            self.block_size, ready_event)
        self.kv_recv_thread.start()
        # Only the receiving thread is awaited before returning.
        ready_event.wait()
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
    """Start loading cached KV for every loadable request in ``metadata``.

    Resets ``current_layer`` and the list of layerwise retrievers. In
    layerwise mode a retriever generator is created per request and
    advanced once; otherwise the request is queued on the receiving
    thread.
    """
    self.current_layer = 0
    self.layerwise_retrievers = []
    for request in metadata.requests:
        load_spec = request.load_spec
        if load_spec is None or not load_spec.can_load:  #load =0
            continue
        tokens = request.token_ids
        req_id = request.req_id

        cached = load_spec.mooncake_cached_tokens
        # One extra token is kept when the cached prefix ends inside a
        # block and covers all tokens but the last.
        ends_inside_block = cached % self.block_size != 0
        if ends_inside_block and cached == tokens.shape[0] - 1:
            tokens = tokens[:cached + 1]
        else:
            tokens = tokens[:cached]

        # Tokens in blocks that vLLM already holds are masked out.
        masked_token_count = (load_spec.vllm_cached_tokens //
                              self.block_size * self.block_size)
        token_mask = torch.ones_like(tokens, dtype=torch.bool)
        token_mask[:masked_token_count] = False

        if not self.use_layerwise:
            self.kv_recv_thread.add_request(  # type: ignore[union-attr]
                req_id,
                tokens,
                request.block_ids,
                token_mask,
            )
            continue

        layerwise_retriever = self.retrieve_layer(
            req_id,
            tokens,
            request.block_ids,
            token_mask,
        )
        next(layerwise_retriever)  # first layer load
        self.layerwise_retrievers.append(layerwise_retriever)
def wait_for_layer_load(self) -> None:
    """Advance every pending layerwise retriever by one step.

    When ``current_layer`` is the last layer, the value yielded by each
    retriever is expected to be the retrieved-token mask and its sum is
    logged.
    """
    on_last_layer = self.current_layer == self.num_layers - 1
    for retriever in self.layerwise_retrievers:
        ret_token_mask = next(retriever)
        if not on_last_layer:
            continue
        assert ret_token_mask is not None
        num_retrieved_tokens = ret_token_mask.sum().item()
        logger.info(f"Retrieved {num_retrieved_tokens} tokens")
def save_kv_layer(self,
                  connector_metadata: "MooncakeConnectorMetadata") -> None:
    """Queue the store of one more layer for every request being saved.

    On the first layer of a step (``current_layer == 0``) a ``store_layer``
    generator is created for each request that still has tokens to store.
    Every call then advances each generator by one layer and increments
    ``current_layer`` exactly once, whatever the number of requests.
    """
    if self.current_layer == 0:
        self.layerwise_storers = []
        for request in connector_metadata.requests:
            save_spec = request.save_spec
            if save_spec is None or not save_spec.can_save:
                continue

            token_ids = request.token_ids
            req_id = request.req_id
            assert isinstance(token_ids, torch.Tensor)
            assert token_ids.is_cpu

            # TODO: whether need to remov saveThread
            # Skip whatever the store already holds or the save spec
            # tells us to skip, whichever is longer.
            skip_leading_tokens = max(
                self.lookup(token_ids, self.use_layerwise),
                save_spec.skip_leading_tokens,
            )
            if skip_leading_tokens == len(token_ids):
                # Nothing left to store; still report completion.
                if request.is_last_chunk:
                    self.kv_send_thread.set_finished_request(  # type: ignore[union-attr]
                        req_id)
                continue  # skip this request

            # Round down to a block boundary.
            skip_leading_tokens = (skip_leading_tokens // self.block_size *
                                   self.block_size)

            store_mask = torch.ones_like(token_ids, dtype=torch.bool)
            store_mask[:skip_leading_tokens] = False
            logger.info(
                "Storing KV cache for %d out of %d tokens "
                "(skip_leading_tokens=%d) for request %s",
                len(token_ids) - skip_leading_tokens,
                len(token_ids),
                skip_leading_tokens,
                request.req_id,
            )

            layerwise_storer = self.store_layer(
                req_id,
                token_ids,
                mask=store_mask,
                block_ids=request.block_ids,
            )
            self.layerwise_storers.append(layerwise_storer)

    # Queue the current layer for every active request.
    for layerwise_storer in self.layerwise_storers:
        next(layerwise_storer)
    self.current_layer = self.current_layer + 1
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
    """Queue a store request for every saveable request in the metadata.

    Requests whose tokens are all already covered are not queued; if such
    a request is on its last chunk it is reported as finished directly.
    """
    for request in connector_metadata.requests:
        save_spec = request.save_spec
        if save_spec is None or not save_spec.can_save:
            continue

        token_ids = request.token_ids
        req_id = request.req_id
        assert isinstance(token_ids, torch.Tensor)
        assert token_ids.is_cpu

        # Longest prefix that needs no store: what the store already has
        # versus what the save spec asks to skip.
        already_present = self.lookup(token_ids, self.use_layerwise)
        skip_leading_tokens = max(already_present,
                                  save_spec.skip_leading_tokens)
        if skip_leading_tokens == len(token_ids):
            if request.is_last_chunk:
                self.kv_send_thread.set_finished_request(  # type: ignore[union-attr]
                    req_id)
            continue  # skip this request

        # Round down to a block boundary.
        skip_leading_tokens = (skip_leading_tokens // self.block_size *
                               self.block_size)

        store_mask = torch.ones_like(token_ids, dtype=torch.bool)
        store_mask[:skip_leading_tokens] = False

        logger.info(
            "Storing KV cache for %d out of %d tokens "
            "(skip_leading_tokens=%d) for request %s",
            len(token_ids) - skip_leading_tokens,
            len(token_ids),
            skip_leading_tokens,
            request.req_id,
        )

        self.kv_send_thread.add_request(  # type: ignore[union-attr]
            req_id,
            token_ids,
            request.block_ids,
            store_mask,
            request.is_last_chunk,
        )
def retrieve_layer(
self,
req_id: str,
tokens: torch.Tensor,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
) -> Generator[Optional[torch.Tensor], None, None]:
"""
Retrieve the KV cache in a layerwise manner.
:param str req_id: The request ID placed in every per-layer request.
:param torch.Tensor tokens: The tokens of the corresponding KV caches.
:param list[int] block_ids: The block IDs of the request, forwarded
unchanged to the receiving thread.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched.
return: A generator that yields Optional[torch.Tensor]. It yields None
once per layer; the final value yielded is a boolean mask marking
the tokens covered by the chunks returned by the token database.
"""
# num_required_tokens is only used in the debug log at the end.
if mask is not None:
num_required_tokens = torch.sum(mask).item()
else:
num_required_tokens = len(tokens)
# CPU boolean mask; set to True for every chunk range processed below.
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
# Parallel lists: one start offset, end offset and key list per chunk.
starts = []
ends = []
keys = []
# True until the first per-layer request has been queued.
first_flag = True
for start, end, key in self.token_database.process_tokens(
tokens, mask):
# One key per layer for this chunk.
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
keys.append(keys_multi_layer)
ret_mask[start:end] = True
if keys:
# Transpose the keys into layer major format
keys = [list(row) for row in zip(*keys)] # [num_layer,block_num]
for layer_id, keys_multi_chunk in enumerate(keys):
if not first_flag:
# Wait up to 3 seconds for get_event, which the layer receiving
# thread sets after it has processed a request.
is_finish = self.get_event.wait(timeout=3) # timed out => False
if not is_finish:
logger.info("Layerwise get failed")
self.get_event.clear()
# The same starts / ends / block_ids objects are shared by the
# requests of all layers.
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id)
self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
first_flag = False
yield None
else:
# If no cache are found, we still need to yield to avoid
# `StopIteration`
for layer_id in range(self.num_layers):
yield None
retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} "
f"out of {num_required_tokens} "
f"out of total {len(tokens)} tokens")
yield ret_mask
def store_layer(
    self,
    req_id: str,
    tokens: torch.Tensor,
    block_ids: list[int],
    mask: Optional[torch.Tensor] = None,
) -> Generator[None, None, None]:
    """
    Store the KV cache in a layerwise manner.

    :param str req_id: The request ID placed in every per-layer request.
    :param torch.Tensor tokens: The tokens of the corresponding KV caches.
    :param list[int] block_ids: The block IDs of the request, forwarded
        unchanged to the sending thread.
    :param Optional[torch.Tensor] mask: The mask for the tokens. Should
        have the same length as tokens. And the mask should ALWAYS be like
        FFFFFTTTTTTT, where True means the tokens needs to be matched.

    return: A generator that yields None once per layer. Before each
        yield, the request for that layer is queued on the sending thread
        (nothing is queued when the token database returns no chunk).
    """
    # Only used in the debug log at the end.
    num_stored_tokens = (len(tokens)
                         if mask is None else torch.sum(mask).item())

    chunk_starts = []
    chunk_ends = []
    keys_by_chunk = []  # [block_num,layer_num]
    for start, end, key in self.token_database.process_tokens(
            tokens, mask):
        layer_keys = key.split_layers(self.num_layers)
        chunk_starts.append(start)
        chunk_ends.append(end)
        keys_by_chunk.append(layer_keys)

    if not keys_by_chunk:
        # Nothing to store: still yield once per layer.
        for _ in range(self.num_layers):
            yield
    else:
        # Transpose to layer-major order: [layer_num,block_num]
        keys_by_layer = [list(row) for row in zip(*keys_by_chunk)]
        for layer_id, keys_multi_chunk in enumerate(keys_by_layer):
            req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
                                               chunk_starts, chunk_ends,
                                               block_ids, layer_id)
            self.kv_send_thread.add_request(  # type: ignore[union-attr, call-arg]
                req_meta)  # type: ignore[union-attr, call-arg, arg-type]
            yield

    logger.debug(
        f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
def get_finished(self) -> tuple[set[str], set[str]]:
    """Drain and return the finished request IDs of both threads.

    Returns ``(done_sending, done_recving)``. The sending thread is only
    consulted for the producer and both roles; otherwise the first set is
    empty.
    """
    if self.kv_role in ['kv_producer', 'kv_both']:
        done_sending = self.kv_send_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
        )
    else:
        done_sending = set()

    done_recving = self.kv_recv_thread.get_and_clear_finished_requests(  # type: ignore[union-attr]
    )

    logger.debug(
        "Number of completed KV cache send requests: %d, receive "
        "requests: %d, tp_rank:%d", len(done_sending), len(done_recving),
        self.tp_rank)
    return done_sending, done_recving
def wait_layer_transfer_finish(self, timeout: float = 10.0):
    """Wait until the sending thread has processed all queued requests.

    Returns as soon as every request put on the sending thread's queue has
    been acknowledged with ``task_done()``, or after ``timeout`` seconds,
    whichever comes first. The default keeps the previous fixed 10 second
    wait as the upper bound. Returns immediately when no sending thread
    exists.

    :param timeout: maximum number of seconds to wait.
    """
    send_thread = self.kv_send_thread
    if send_thread is None:
        return
    pending = send_thread.request_queue
    # Same condition Queue.join() waits on, but bounded by a timeout so an
    # unacknowledged request cannot block the caller forever.
    with pending.all_tasks_done:
        pending.all_tasks_done.wait_for(
            lambda: pending.unfinished_tasks == 0, timeout=timeout)
def lookup(
    self,
    tokens: Union[torch.Tensor, List[int]],
    use_layerwise: bool,
) -> int:
    """
    Checks the existence of KV cache of the tokens from the cache engine.

    Chunks are probed in order; the scan stops at the first chunk that is
    missing or whose probe raises.

    :param tokens: the input tokens, with shape [seq_len]
    :param use_layerwise: when true, a chunk counts as present only if
        the keys of all its layers exist.

    :return: An int indicating how many prefix tokens are cached.
    """
    end = 0
    for start, end, key in self.token_database.process_tokens(tokens):
        try:
            if not use_layerwise:
                hit = self.m_store.exists(key)
            else:
                layer_keys = [
                    layer_key.to_string()
                    for layer_key in key.split_layers(self.num_layers)
                ]
                # batch is_exists
                flags = self.m_store.batch_exists(layer_keys)
                # A single missing layer makes the whole chunk a miss.
                hit = 0 if any(flag != 1 for flag in flags) else 1
            if hit == 1:
                continue
            return start
        except Exception as e:
            logger.warning(f"Remote connection failed in contains: {e}")
            return start

    # all tokens where found, return the maximal end
    return end
def close(self) -> None:
    """Shut the engine down by closing the underlying Mooncake store."""
    store = self.m_store
    store.close()

View File

@@ -0,0 +1,88 @@
# Standard
import os
# Third Party
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.utils import logger
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
from .config_data import MooncakeStoreConfig
# NOTE(review): not referenced in the visible part of this module;
# presumably the size in bytes of a metadata header -- confirm before use.
METADATA_BYTES_LEN = 24
# Base port, overridable through the VLLM_BASE_PORT environment variable.
# The device id is added to it to build the per-device port of the local
# hostname when the "ascend" protocol is configured.
BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790"))
class Mooncakestore():
    """Wrapper around ``MooncakeDistributedStore`` for KV-cache blocks.

    The store is configured from the file named by ``MOONCAKE_CONFIG_PATH``.
    ``get`` and ``put`` are best-effort: failures are logged and never
    raised to the caller.
    """

    def __init__(self, parallel_config: "ParallelConfig"):
        """Set up the store connection for the current tensor-parallel rank.

        Raises:
            ImportError: if the ``mooncake`` package is not installed.
            RuntimeError: if the store setup reports a non-zero status.
        """
        try:
            from mooncake.store import MooncakeDistributedStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = parallel_config.tensor_parallel_size
        dp_rank = parallel_config.data_parallel_rank_local
        all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
        if not all_device_ids:
            # No explicit device list: use the contiguous id range of
            # this data-parallel rank.
            device_ids_list = list(
                range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
        else:
            device_ids_list = list(map(int, all_device_ids.split(',')))
        assert len(device_ids_list) > tp_rank
        device_id = device_ids_list[tp_rank]
        self.config = MooncakeStoreConfig.load_from_env()
        if self.config.protocol == "ascend":
            # The "ascend" protocol gets a per-device port and NPU tag.
            local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + int(device_id)) + \
                ":npu_" + str(device_id)
        else:
            local_hostname = self.config.local_hostname
        self.store = MooncakeDistributedStore()
        ret = self.store.setup(local_hostname, self.config.metadata_server,
                               self.config.global_segment_size,
                               self.config.local_buffer_size,
                               self.config.protocol, self.config.device_name,
                               self.config.master_server_address)
        if ret != 0:
            msg = "Initialize mooncake failed."
            logger.error(msg)
            raise RuntimeError(msg)

    def set_kv_caches(self, kvcache):
        """Keep a list copy of the given cache iterable."""
        self.kvcache = list(kvcache)

    def exists(self, key: "MooncakeEngineKey") -> bool:
        """Return whether the store reports ``key`` as present."""
        return self.store.is_exist(key.to_string()) == 1

    def batch_exists(self, keys: list[str]) -> list[bool]:
        """Return the store's existence result for each key string."""
        return self.store.batch_is_exist(keys)

    def get(self, key: "MooncakeEngineKey", addr: list[int], size: list[int]):
        """Read the value of ``key`` into the given device addresses.

        A failure is logged, not raised. Returns the store's result, or
        ``None`` when the store call itself raised.
        """
        expect_res = sum(size)
        key_str = key.to_string()
        # Defined up front so a raising store call cannot leave it unbound.
        res = None
        try:
            res = self.store.batch_get_into_ascend(key_str, addr, size)
            if res[0] != expect_res:
                logger.error(f"Failed to get key: [{key_str}] .")
        except Exception:
            logger.error(f"Failed to get key: [{key_str}] .")
        return res

    def put(self, key: "MooncakeEngineKey", addr: list[int], size: list[int]):
        """Write the given device memory ranges to the store under ``key``.

        A failure is logged, not raised. Returns the store's result, or
        ``None`` when the store call itself raised.
        """
        key_str = key.to_string()
        # Defined up front so a raising store call cannot leave it unbound.
        ret = None
        try:
            ret = self.store.batch_put_from_ascend(key_str, addr, size)
            if ret[0] != 0:
                logger.error(f"Failed to put key {key_str}.")
        except Exception:
            logger.error(f"Failed to put key {key_str}.")
        return ret

    def close(self):
        """Close the store connection."""
        self.store.close()
        logger.info("Closed the mooncake store connection")

View File

@@ -0,0 +1,484 @@
import threading
from typing import Any, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.forward_context import ForwardContext
from vllm.utils import logger, make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm_ascend.distributed.mooncake.config_data import (
LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker)
from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine
class MooncakeConnectorV1(KVConnectorBase_V1):
    """KV connector that delegates to a scheduler-side or worker-side helper.

    Depending on ``role`` exactly one of ``connector_scheduler`` and
    ``connector_worker`` is created; the other one stays ``None``.
    """

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self.kv_role = vllm_config.kv_transfer_config.kv_role

        self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "use_layerwise", False)

        self.kv_caches: dict[str, torch.Tensor] = {}

        self._block_size = vllm_config.cache_config.block_size

        # Requests whose send completed while the request itself was
        # still listed as unfinished.
        self.sended_but_unfinished_reqs: set[str] = set()

        # Both attributes always exist so that the ``is not None`` asserts
        # below fail with a clear AssertionError on the wrong role.
        self.connector_scheduler: Optional[
            MooncakeStoreConnectorV1Scheduler] = None
        self.connector_worker: Optional[MooncakeEngine] = None

        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(
                vllm_config, self.use_layerwise)
        else:
            self.connector_worker = MooncakeEngine(
                vllm_config,
                self.use_layerwise,
            )

            # Only rank 0 hosts the lookup server.
            if vllm_config.parallel_config.rank == 0:
                self.lookup_server = MooncakeLookupServer(
                    self.connector_worker, vllm_config, self.use_layerwise)

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Delegate to the scheduler-side connector."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Delegate to the scheduler-side connector."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side connector."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Delegate to the scheduler-side connector."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Hand the KV caches to the worker-side engine."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Start loading KV for the requests in the current metadata."""
        assert self.connector_worker is not None
        metadata = self._get_connector_metadata()
        assert isinstance(metadata, MooncakeConnectorMetadata)
        self.connector_worker.start_load_kv(metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Advance the layerwise load by one layer; no-op unless layerwise."""
        if not self.use_layerwise:
            return
        assert self.connector_worker is not None
        self.connector_worker.wait_for_layer_load()

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """Queue the save of one layer; no-op unless layerwise."""
        if not self.use_layerwise:
            return

        if self.kv_role == "kv_consumer":
            # Don't do save if the role is kv_consumer
            return
        assert self.connector_worker is not None
        self.connector_worker.save_kv_layer(self._get_connector_metadata())

    def wait_for_save(self):
        """Queue the whole-request save, or wait for the layerwise sends."""
        if self.kv_role == "kv_consumer":
            # Don't do save if the role is kv_consumer
            return

        assert self.connector_worker is not None
        if self.use_layerwise:
            self.connector_worker.wait_layer_transfer_finish()
            return

        self.connector_worker.wait_for_save(self._get_connector_metadata())

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests.

        A request whose send completed is only reported once it is no
        longer in the metadata's ``unfinished_request_ids``; until then it
        is parked in ``sended_but_unfinished_reqs``.
        """
        assert self.connector_worker is not None
        meta = self._get_connector_metadata()
        done_sending, done_recving = self.connector_worker.get_finished()
        sended_and_finished: set[str] = set()
        # Release parked requests that have finished since the last call.
        for item in list(self.sended_but_unfinished_reqs):
            if item not in meta.unfinished_request_ids:
                sended_and_finished.add(item)
                self.sended_but_unfinished_reqs.remove(item)
        # Park or release the sends that completed just now.
        for item in done_sending:
            if item in meta.unfinished_request_ids:
                self.sended_but_unfinished_reqs.add(item)
            else:
                sended_and_finished.add(item)

        return sended_and_finished, done_recving
def get_zmq_rpc_path_mooncake(
        vllm_config: Optional["VllmConfig"] = None, ) -> str:
    """Return the ``ipc://`` path used for the Mooncake lookup RPC.

    The path is built from ``envs.VLLM_RPC_BASE_PATH`` and the
    ``mooncake_rpc_port`` entry of the KV-transfer extra config; the port
    is 0 when no config is given or the entry is missing.
    """
    base_url = envs.VLLM_RPC_BASE_PATH
    if vllm_config is None:
        # Default to 0 if not configured
        rpc_port = 0
    else:
        rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
            "mooncake_rpc_port", 0)
    logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
    return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}"
class MooncakeStoreConnectorV1Scheduler:
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
    """Create the lookup client and the empty per-request bookkeeping."""
    kv_transfer_config = vllm_config.kv_transfer_config

    self.client = MooncakeLookupClient(vllm_config)
    self.use_layerwise = use_layerwise
    self.kv_role = kv_transfer_config.kv_role
    self._block_size = vllm_config.cache_config.block_size
    # Whether to discard partial chunks
    self._discard_partial_chunks = (
        kv_transfer_config.get_from_extra_config("discard_partial_chunks",
                                                 True))

    # request_id -> (vllm cached tokes, mooncake cached tokens)
    self.load_specs: dict[str, LoadSpec] = {}
    # request_id -> full_token_ids
    self._request_trackers: dict[str, RequestTracker] = {}
    # request_id -> (request, local block ids), plus the bare id set.
    self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {}
    self._unfinished_request_ids: set[str] = set()
def get_num_new_matched_tokens(
    self,
    request: "Request",
    num_computed_tokens: int,
) -> tuple[int, bool]:
    """
    Check for external KV cache hit.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A pair: the number of tokens that can be loaded from the
        external KV cache beyond what is already computed, and a flag
        that is the negation of ``use_layerwise`` (``False`` when there
        is nothing to load).
    """
    if not self._discard_partial_chunks:
        token_ids = torch.tensor(request.prompt_token_ids)
    else:
        # Only look up whole blocks of the prompt.
        num_full_blocks = len(request.prompt_token_ids) // self._block_size
        token_block_end = num_full_blocks * self._block_size
        token_ids = torch.tensor(
            request.prompt_token_ids[:token_block_end])

    num_external_hit_tokens = self.client.lookup(token_ids)

    # A hit on every token of the request is reduced by one token.
    if num_external_hit_tokens == request.num_tokens:
        num_external_hit_tokens -= 1

    need_to_allocate = num_external_hit_tokens - num_computed_tokens

    logger.info(
        "Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
        request.request_id,
        request.num_tokens,
        num_external_hit_tokens,
        need_to_allocate,
    )

    if need_to_allocate <= 0:
        return 0, False

    # Remember what to load; can_load is set later, after allocation.
    self.load_specs[request.request_id] = LoadSpec(
        vllm_cached_tokens=num_computed_tokens,
        mooncake_cached_tokens=num_external_hit_tokens,
        can_load=False,
    )

    return need_to_allocate, not self.use_layerwise
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after temporary buffer alloc.
For SharedStorageConnector, update _request_needs_load
if the CacheManager this allocated blocks for us.
"""
local_block_ids = []
if num_external_tokens > 0:
local_block_ids = blocks.get_block_ids()[0]
self._unfinished_requests[request.request_id] = (request,
local_block_ids)
self._unfinished_request_ids.add(request.request_id)
if request.request_id not in self.load_specs:
# No KV tokens from external KV cache, return
return
if num_external_tokens == 0:
# No need to load anything
self.load_specs[request.request_id].can_load = False
return
assert (
num_external_tokens > 0 and num_external_tokens
== self.load_specs[request.request_id].mooncake_cached_tokens -
self.load_specs[request.request_id].vllm_cached_tokens
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}")
self.load_specs[request.request_id].can_load = True
def build_connector_meta(
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
"""Attach the connector metadata to the request object.
This function should NOT modify other fields in the scheduler_output
except the `kv_connector_metadata` field.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save = self.kv_role == "kv_consumer"
for finished_req_id in scheduler_output.finished_req_ids:
self._request_trackers.pop(finished_req_id, None)
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.remove(finished_req_id)
meta = MooncakeConnectorMetadata(self._unfinished_request_ids)
for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests
load_spec = self.load_specs.pop(request.req_id, None)
num_tokens_to_compute = (
request.num_computed_tokens +
scheduler_output.num_scheduled_tokens[request.req_id])
request_tracker = RequestTracker.from_new_request(
request, num_tokens_to_compute)
self._request_trackers[request.req_id] = request_tracker
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
cached_reqs = scheduler_output.scheduled_cached_reqs
if isinstance(cached_reqs, list) and not force_skip_save:
for i, req in enumerate(cached_reqs):
request_tracker = self._request_trackers[req.req_id]
request_tracker.update(req.new_token_ids, req.new_block_ids)
last_chunk_tokens_num = ((len(req.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(req.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
elif not force_skip_save:
for i, req_id in enumerate(cached_reqs.req_ids):
request_tracker = self._request_trackers[req_id]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_tuple = self._unfinished_requests.get(req_id)
if req_tuple:
request = req_tuple[0]
num_current_tokens = len(request_tracker.token_ids)
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
else:
raise ValueError(
f"Request {req_id} is not in _unfinished_requests, "
f"but it is scheduled to be cached")
new_block_ids = cached_reqs.new_block_ids[i]
if not new_block_ids:
continue
request_tracker.update(new_token_ids, new_block_ids)
# decode not save
if len(request_tracker.token_ids) > len(
request.prompt_token_ids):
continue
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(request.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
request_ids = [
req.req_id for req in scheduler_output.scheduled_new_reqs
]
for request_id, (request,
block_ids) in self._unfinished_requests.items():
if request_id not in request_ids and request_id not in cached_reqs.req_ids:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.mooncake_cached_tokens
if (num_tokens_to_compute % self._block_size
!= 0) and (num_tokens_to_compute
== len(request.prompt_token_ids) - 1):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
token_ids=request.prompt_token_ids[:num_tokens_to_compute].
copy(),
allocated_block_ids=block_ids,
num_saved_tokens=0,
)
self._request_trackers[request_id] = request_tracker
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=None,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
if self.kv_role == "kv_consumer":
return False, None
if self._request_trackers[request.request_id].num_saved_tokens <= 0:
return False, None
delay_free_blocks = len(block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(block_ids), request.request_id)
return delay_free_blocks, None
class MooncakeLookupClient:
    """ZMQ REQ-socket client used to query the cached-token count."""

    def __init__(self, vllm_config: "VllmConfig"):
        """Create the msgpack encoder and connect the REQ socket."""
        self.encoder = MsgpackEncoder()
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        self.socket = make_zmq_socket(
            self.ctx,
            get_zmq_rpc_path_mooncake(vllm_config),
            zmq.REQ,  # type: ignore[attr-defined]
            bind=False,
        )

    def lookup(self, token_ids: torch.Tensor) -> int:
        """Send ``token_ids`` and return the reply decoded as a big-endian int."""
        frames = self.encoder.encode(token_ids)
        self.socket.send_multipart(frames, copy=False)
        reply = self.socket.recv()
        return int.from_bytes(reply, "big")

    def close(self):
        """Close the socket without lingering."""
        self.socket.close(linger=0)
class MooncakeLookupServer:
    """ZMQ REP-socket server answering lookups from a background thread."""

    def __init__(
        self,
        mooncake_engine: MooncakeEngine,
        vllm_config: "VllmConfig",
        use_layerwise: bool,
    ):
        """Bind the REP socket and start the daemon thread that serves it.

        Each request is decoded into a tensor, passed to
        ``mooncake_engine.lookup`` together with ``use_layerwise``, and the
        integer result is sent back as 4 big-endian bytes.
        """
        self.decoder = MsgpackDecoder(torch.Tensor)
        self.ctx = zmq.Context()  # type: ignore[attr-defined]
        self.socket = make_zmq_socket(
            self.ctx,
            get_zmq_rpc_path_mooncake(vllm_config),
            zmq.REP,  # type: ignore[attr-defined]
            bind=True,
        )

        self.mooncake_engine = mooncake_engine
        self.running = True

        def process_request():
            # Strict request/reply loop: one send per received message.
            while self.running:
                request_frames = self.socket.recv_multipart(copy=False)
                hit_tokens = self.mooncake_engine.lookup(
                    self.decoder.decode(request_frames), use_layerwise)
                self.socket.send(hit_tokens.to_bytes(4, "big"))

        self.thread = threading.Thread(target=process_request, daemon=True)
        self.thread.start()

    def close(self):
        """Close the socket without lingering; the thread is not stopped."""
        self.socket.close(linger=0)
        # TODO: close the thread!

View File

@@ -11,7 +11,7 @@ from collections import defaultdict, deque
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple
import msgspec
import numpy as np
@@ -19,6 +19,7 @@ import numpy.typing as npt
import torch
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -29,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -67,12 +69,16 @@ class KVCacheTaskTracker:
# intentionally delayed. Each entry is a tuple of (request_id,
# timestamp). If a request remains in this queue for too long, it will
# be force-freed.
self.delayed_free_requests: deque[Tuple[str, float]] = deque()
self.record_finished_requests: set[str] = set()
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
def update_done_task_count(self, request_id: str):
with self.done_task_lock:
self.finished_requests.add(request_id)
self._remove_delayed_requests(request_id)
if request_id in self.delayed_free_requests:
self._remove_delayed_requests(request_id)
else:
self.record_finished_requests.add(request_id)
def get_and_clear_finished_requests(self) -> set[str]:
"""
@@ -90,7 +96,10 @@ class KVCacheTaskTracker:
def add_delayed_request(self, request_id: str, delay_start_time: float):
"""Add a delayed free request."""
with self.done_task_lock:
self.delayed_free_requests.append((request_id, delay_start_time))
if request_id not in self.record_finished_requests:
self.delayed_free_requests[request_id] = delay_start_time
else:
self.record_finished_requests.discard(request_id)
def _retrieve_expired_requests(self):
"""Retrieve all expired delayed requests."""
@@ -98,10 +107,11 @@ class KVCacheTaskTracker:
# Free delayed requests if they exceed the timeout
current_time = time.time()
while self.delayed_free_requests:
request_id, delay_start_time = self.delayed_free_requests[0]
request_id = next(iter(self.delayed_free_requests))
delay_start_time = self.delayed_free_requests[request_id]
if (current_time - delay_start_time
> envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT):
self.delayed_free_requests.popleft()
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
self.delayed_free_requests.popitem(last=False)
expired_requests.add(request_id)
logger.info("Force freed request: %s", request_id)
else:
@@ -110,8 +120,7 @@ class KVCacheTaskTracker:
def _remove_delayed_requests(self, request_id: str):
"""Remove all delayed free requests matching the given request_id."""
self.delayed_free_requests = deque(
(r, t) for r, t in self.delayed_free_requests if r != request_id)
self.delayed_free_requests.pop(request_id)
class KVCacheSendingThread(threading.Thread):
@@ -230,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread):
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.use_sfa = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
@@ -341,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread):
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
block_len = (self.block_len[k % 2]
if self.use_mla else self.block_len[0])
if self.use_mla:
block_len = (self.block_len[k % 2])
elif self.use_sfa:
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
for i, remote_block_id in enumerate(grouped_remote_block_ids):
local_block_ids = grouped_local_block_ids[i]
src = src_layer_base_addr + local_block_ids[0] * block_len
@@ -559,6 +573,7 @@ class MooncakeConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.ascend_config = get_ascend_config()
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
logger.info("Initializing Mooncake Scheduler %s", engine_id)
@@ -718,7 +733,7 @@ class MooncakeConnectorScheduler:
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
if self.vllm_config.model_config.use_mla:
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
return self._decode_tp_size
else:
# TODO support mha and gqa
@@ -782,10 +797,12 @@ class MooncakeConnectorWorker:
assert len(device_ids) > self.tp_rank # type: ignore
self.device_id = device_ids[self.tp_rank] # type: ignore
self._initialize(
hostname=self.side_channel_host + ':' + '0' + ':' + 'npu_' \
+ str(self.device_id),
device_name=None)
if vllm_config.kv_transfer_config.get_from_extra_config(
'use_ascend_direct', False):
hostname = self.side_channel_host
else:
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
self._initialize(hostname=hostname, device_name=None)
self.te_rpc_port = self.engine.get_rpc_port()
# Background thread for sending or receiving KV caches.
@@ -837,7 +854,9 @@ class MooncakeConnectorWorker:
# TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sfa = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -851,6 +870,21 @@ class MooncakeConnectorWorker:
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
elif self.use_sfa:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks, block_shape_norm, block_shape_pe,
block_shape_k)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -861,8 +895,9 @@ class MooncakeConnectorWorker:
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.info(
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
self.use_mla, self.use_sfa, first_kv_cache.shape)
self.kv_caches = kv_caches
kv_caches_base_addr = []
@@ -874,9 +909,16 @@ class MooncakeConnectorWorker:
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
elif self.use_sfa:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
cache_list = [
cache_or_caches
] if self.use_mla or self.use_sfa else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]

View File

@@ -11,7 +11,7 @@ from vllm_ascend.ascend_config import get_ascend_config
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
@@ -20,6 +20,12 @@ def get_mc2_group() -> GroupCoordinator:
return _MC2
def get_otp_group() -> GroupCoordinator:
    """Return the output tensor parallel group held in the module-level ``_OTP``.

    Fails an assertion when ``_OTP`` is still None.
    """
    assert _OTP is not None, "output tensor parallel group is not initialized"
    return _OTP
def get_lmhead_tp_group() -> GroupCoordinator:
assert _LMTP is not None, (
"lm head tensor parallel group is not initialized")
@@ -74,6 +80,20 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
backend,
group_name="mlp_tp")
# If oproj tensor parallel size is set, we will create a group for it.
otp_size = get_ascend_config().oproj_tensor_parallel_size
if otp_size is not None:
group_ranks = []
global _OTP
num_oproj_tensor_parallel_groups: int = (world_size // otp_size)
for i in range(num_oproj_tensor_parallel_groups):
ranks = list(range(i * otp_size, (i + 1) * otp_size))
group_ranks.append(ranks)
_OTP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="otp")
lmhead_tensor_parallel_size = get_ascend_config(
).lmhead_tensor_parallel_size
if lmhead_tensor_parallel_size is not None:
@@ -117,3 +137,8 @@ def destroy_ascend_model_parallel():
if _LMTP:
_LMTP.destroy()
_LMTP = None
global _OTP
if _OTP:
_OTP.destroy()
_OTP = None

View File

@@ -1,248 +0,0 @@
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapts from: Megatron/megatron/core/tensor_parallel/mappings.py.
# This file is a part of the vllm-ascend project.
import torch
def _gather_along_first_dim(input_, group, output_split_sizes=None):
    """All-gather ``input_`` over ``group`` and concatenate along dim 0.

    Args:
        input_ (torch.Tensor): Local tensor to contribute.
        group: Process group to gather over.
        output_split_sizes (List[int], optional): First-dimension size of
            each piece of the output. When None, the output's first
            dimension is ``input_``'s multiplied by the group size.

    Returns:
        torch.Tensor: ``input_`` itself when the group has a single rank,
        otherwise a newly allocated tensor holding the gathered data.
    """
    group_size = torch.distributed.get_world_size(group)
    if group_size == 1:
        # Single rank: nothing to gather.
        return input_

    equal_split = output_split_sizes is None
    gathered_shape = list(input_.size())
    if equal_split:
        gathered_shape[0] = gathered_shape[0] * group_size
    else:
        gathered_shape[0] = sum(output_split_sizes)

    gathered = torch.empty(gathered_shape,
                           dtype=input_.dtype,
                           device=torch.npu.current_device())
    if equal_split:
        torch.distributed.all_gather_into_tensor(gathered,
                                                 input_.contiguous(),
                                                 group=group)
    else:
        # The split pieces of `gathered` are handed to all_gather as the
        # per-rank destination tensors.
        destinations = list(torch.split(gathered, output_split_sizes, dim=0))
        torch.distributed.all_gather(destinations, input_, group=group)
    return gathered
def _gather_along_last_dim(input_, group):
    """All-gather ``input_`` over ``group`` and concatenate along the last dim.

    Returns ``input_`` unchanged when the group has a single rank.
    """
    group_size = torch.distributed.get_world_size(group)
    if group_size == 1:
        return input_

    stacked_shape = list(input_.size())
    stacked_shape[0] *= group_size
    stacked = torch.empty(stacked_shape,
                          dtype=input_.dtype,
                          device=torch.npu.current_device())
    torch.distributed.all_gather_into_tensor(stacked,
                                             input_.contiguous(),
                                             group=group)
    # The gather stacks along dim 0; cut that into one chunk per rank and
    # join the chunks along the last dim instead.
    per_rank_chunks = stacked.chunk(group_size, dim=0)
    return torch.cat(per_rank_chunks, dim=-1).contiguous()
def _reduce_scatter_along_first_dim(input_,
                                    group,
                                    input_split_sizes=None,
                                    use_global_buffer=False):
    """Reduce-scatter ``input_`` over ``group`` along the first dimension.

    Args:
        input_ (torch.Tensor): The tensor to be reduce-scattered.
        group: Process group to operate on.
        input_split_sizes (List[int], optional): First-dimension size of
            the slice destined for each rank. When None, the first
            dimension is divided evenly and must be divisible by the
            group size.
        use_global_buffer: Accepted but not used by this function.

    Returns:
        torch.Tensor: ``input_`` itself when the group has a single rank,
        otherwise this rank's reduced slice.
    """
    group_size = torch.distributed.get_world_size(group)
    if group_size == 1:
        return input_

    if input_split_sizes is not None:
        # Uneven split: the output takes the shape of this rank's slice.
        my_rank = torch.distributed.get_rank(group)
        slices = list(torch.split(input_, input_split_sizes, dim=0))
        reduced = torch.empty_like(slices[my_rank])
        torch.distributed.reduce_scatter(reduced, slices, group=group)
        return reduced

    reduced_shape = list(input_.size())
    assert (
        reduced_shape[0] % group_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"
    reduced_shape[0] //= group_size
    reduced = torch.empty(reduced_shape,
                          dtype=input_.dtype,
                          device=torch.npu.current_device())
    torch.distributed.reduce_scatter_tensor(reduced,
                                            input_.contiguous(),
                                            group=group)
    return reduced
def _reduce_scatter_along_last_dim(input_, group):
    """Reduce-scatter ``input_`` over ``group`` along the last dimension.

    The input is flattened to 2-D, its columns are cut into one slice per
    rank and stacked along dim 0, the first-dim reduce-scatter is applied,
    and the result is reshaped to the input shape with the last dimension
    divided by the group size.
    """
    group_size = torch.distributed.get_world_size(group)

    result_shape = list(input_.size())
    result_shape[-1] //= group_size

    flat = input_.reshape(-1, input_.shape[-1])
    column_slices = torch.split(flat,
                                split_size_or_sections=flat.shape[-1] //
                                group_size,
                                dim=1)
    stacked = torch.cat(column_slices, dim=0)
    reduced = _reduce_scatter_along_first_dim(stacked, group)
    return reduced.reshape(result_shape)
def all_gather_last_dim_from_tensor_parallel_region(input_, group):
    """All-gather ``input_`` over ``group`` along the last dimension.

    Thin wrapper that delegates to ``_gather_along_last_dim``.
    """
    gathered = _gather_along_last_dim(input_, group)
    return gathered
def reduce_scatter_to_sequence_parallel_region(input_,
                                               group,
                                               input_split_sizes=None):
    """Reduce-scatter ``input_`` over ``group`` along the first dimension.

    Thin wrapper that delegates to ``_reduce_scatter_along_first_dim``,
    forwarding the optional per-rank ``input_split_sizes``.
    """
    scattered = _reduce_scatter_along_first_dim(input_, group,
                                                input_split_sizes)
    return scattered
def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group):
    """Reduce-scatter ``input_`` over ``group`` along the last dimension.

    Thin wrapper that delegates to ``_reduce_scatter_along_last_dim``.
    """
    scattered = _reduce_scatter_along_last_dim(input_, group)
    return scattered
def gather_from_sequence_parallel_region(
    input_,
    group,
    output_split_sizes=None,
):
    """All-gather ``input_`` over ``group`` along the first dimension.

    Thin wrapper that delegates to ``_gather_along_first_dim``, forwarding
    the optional ``output_split_sizes``.
    """
    gathered = _gather_along_first_dim(input_, group, output_split_sizes)
    return gathered
def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None):
    """Exchange slices of ``input`` between all ranks of ``group``.

    Args:
        group: Process group to operate on.
        input (torch.Tensor): Tensor to send; made contiguous first.
        output_split_sizes (List[int], optional): First-dimension sizes of
            the output pieces. When None the output has the shape of the
            input.
        input_split_sizes (List[int], optional): Forwarded unchanged to
            ``all_to_all_single``.

    Returns:
        torch.Tensor: ``input`` itself when the group has a single rank,
        otherwise the newly allocated receive buffer.
    """
    if torch.distributed.get_world_size(group=group) == 1:
        # Single rank: nothing to exchange.
        return input

    send_buf = input.contiguous()
    if output_split_sizes is not None:
        # Unequal split (all2all-v)
        recv_buf = send_buf.new_empty(
            size=[sum(output_split_sizes)] + list(send_buf.size()[1:]),
            dtype=send_buf.dtype,
            device=torch.npu.current_device(),
        )
    else:
        # Equal split (all2all)
        recv_buf = torch.empty_like(send_buf)

    torch.distributed.all_to_all_single(
        recv_buf,
        send_buf,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    return recv_buf
def all_to_all_sp2hp(input_, group):
    """Turn a sequence-split tensor into a hidden-split one via AlltoAll.

    Transforms the input from shape [num_tokens/TP, H] to
    [num_tokens, H/TP].

    Args:
        input_ (torch.Tensor): The input tensor which has been distributed
            along the sequence dimension.
        group: Process group to exchange over; when None the input is
            returned untouched.

    Returns:
        torch.Tensor: The output tensor with shape [num_tokens, H/TP].
    """
    if group is None:
        return input_

    group_size = torch.distributed.get_world_size(group=group)
    rows = input_.reshape(-1, input_.shape[-1])
    # One hidden-dimension slice per rank, stacked along dim 0 for the
    # exchange.
    hidden_slices = torch.split(rows,
                                split_size_or_sections=rows.shape[-1] //
                                group_size,
                                dim=1)
    return all_to_all(group, torch.cat(hidden_slices, dim=0))
def all_to_all_hp2sp(input_, group):
    """Turn a hidden-split tensor into a sequence-split one via AlltoAll.

    Transforms the input from shape [num_tokens, H/TP] to
    [num_tokens/TP, H].

    Args:
        input_ (torch.Tensor): The input tensor which has been distributed
            along the hidden dimension.
        group: Process group to exchange over; when None the input is
            returned untouched.

    Returns:
        torch.Tensor: The output tensor with shape [num_tokens/TP, H].
    """
    if group is None:
        return input_

    group_size = torch.distributed.get_world_size(group=group)
    rows = input_.reshape(-1, input_.shape[-1])
    exchanged = all_to_all(group, rows)
    exchanged_2d = exchanged.reshape(-1, exchanged.shape[-1])
    # Cut the exchanged rows into one slice per rank and join the slices
    # along the hidden dimension.
    token_slices = torch.split(
        exchanged_2d,
        split_size_or_sections=exchanged_2d.shape[0] // group_size,
        dim=0)
    return torch.cat(token_slices, dim=-1)

View File

@@ -131,6 +131,26 @@ env_variables: Dict[str, Callable[[], Any]] = {
# this feature is supported in A2, and eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
# Whether to enable FlashComm optimization when tensor parallel is enabled.
# This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
# buffer size for gate up prefetch
"VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE":
lambda: int(
os.getenv("VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE", 18 * 1024 * 1024)),
# buffer size for down proj prefetch
"VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE":
lambda: int(
os.getenv("VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE", 18 * 1024 * 1024)),
# Whether to enable dense model and general optimizations for better performance.
# Since we modified the base parent class `linear`, this optimization is also applicable to other model types.
# However, there might be hidden issues, and it is currently recommended to prioritize its use with dense models.
"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE", '0'))),
# Whether to enable mlp optimize when tensor parallel is enabled.
# this feature in eager mode will get better performance.
"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE":
@@ -139,11 +159,16 @@ env_variables: Dict[str, Callable[[], Any]] = {
# caused by the initialization of the Mooncake connector.
"PHYSICAL_DEVICES":
lambda: os.getenv("PHYSICAL_DEVICES", None),
# Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
"MSMONITOR_USE_DAEMON":
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
# Timeout (in seconds) for delayed KVCache block release. In the prefill
# node, if a request is marked for delayed KV block release and the blocks
# are not freed within this timeout, they will be forcibly released.
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
"VLLM_ASCEND_ENABLE_MLAPO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
}
# end-env-vars-definition
@@ -157,4 +182,4 @@ def __getattr__(name: str):
def __dir__():
return list(env_variables.keys())
return list(env_variables.keys())

View File

View File

@@ -0,0 +1,44 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor.
from abc import abstractmethod
from typing import Any
class EplbAdaptor:
    """Interface stub for EPLB adaptors.

    Every operation raises ``NotImplementedError`` here; concrete adaptors
    are expected to override them.
    """

    def __init__(self, **args):
        """Accept and ignore arbitrary keyword arguments."""

    @abstractmethod
    def get_rank_expert_workload(self):
        """Return the expert workload of this rank (not implemented here)."""
        raise NotImplementedError

    @abstractmethod
    def get_init_expert_map(self, num_moe_layers: Any) -> Any:
        """Return the initial expert map (not implemented here)."""
        raise NotImplementedError

    @abstractmethod
    def do_update_expert_map(self, layer_id: Any,
                             updated_expert_map: Any) -> Any:
        """Apply an updated expert map to a layer (not implemented here)."""
        raise NotImplementedError

    @abstractmethod
    def do_update_expert_weight(self, layer_id: Any,
                                local_expert_to_replace: Any,
                                buffer_tensor_id: Any) -> Any:
        """Update the weight of a local expert (not implemented here)."""
        raise NotImplementedError

View File

@@ -0,0 +1,289 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this adaptor.
import json
from typing import Any
import torch
import torch.distributed as dist
from vllm.logger import logger
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
class VllmEplbAdaptor(EplbAdaptor):
def __init__(self, model, **args):
    """Collect the expert maps, weights and buffers of ``model``.

    Reads the layer/expert counts from ``model.config``, picks the expert
    weight names depending on whether ``model.quant_config`` is set, and
    builds the per-layer expert map, expert parameter and log2phy tables
    plus the buffer tensors.
    """
    super().__init__(**args)
    self.model = model
    self.rank_id = dist.get_rank()
    self.world_size = dist.get_world_size()
    self.param_dict = dict(self.model.named_parameters())

    config = self.model.config
    if config.model_type == "qwen3_moe":
        self.num_dense_layers = 0
        self.global_expert_num = config.num_experts
    else:
        self.num_dense_layers = config.first_k_dense_replace
        self.global_expert_num = config.n_routed_experts
    self.num_moe_layers = config.num_hidden_layers - self.num_dense_layers
    self.init_redundancy_expert = get_ascend_config(
    ).init_redundancy_expert

    # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
    if self.model.quant_config is None:
        self.expert_weight_names = ["w13_weight", "w2_weight"]
    else:
        self.expert_weight_names = [
            "w13_weight", "w2_weight", "w13_weight_scale",
            "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
        ]

    first_moe_layer = self.num_dense_layers
    moe_layer_ids = range(first_moe_layer,
                          first_moe_layer + self.num_moe_layers)

    # reference to expert map on device for expert map update
    self.expert_map_per_layer = {
        layer_id: self.model.get_expert_map(layer_id)
        for layer_id in moe_layer_ids
    }
    # copy of expert map on CPU to avoid device synchronize frequently
    self.expert_map_per_layer_cpu = dict()

    # TODO: here we set number of buffer tensor equal to number of expert in each layer, which can be improved
    local_slots = torch.where(
        self.expert_map_per_layer[first_moe_layer] != -1)[0]
    num_buffer_tensor = local_slots.numel()
    self.buffer_tensor_list: list[list[Any]] = [
        [] for _ in range(num_buffer_tensor)
    ]
    self.init_buffer_tensor(num_buffer_tensor)

    self.expert_param_per_layer = dict()
    self.init_expert_param_per_layer()

    self.log2phy_map_per_layer = {
        layer_id: self.model.get_log2phy_map(layer_id)
        for layer_id in moe_layer_ids
    }

    self.all_topk_ids = []
def init_buffer_tensor(self, num_buffer_tensor):
    """Allocate one uninitialised buffer per expert weight for each slot.

    For every name in ``expert_weight_names`` an ``empty_like`` copy of the
    first ``num_buffer_tensor`` experts of the first MoE layer is created,
    and row ``i`` of it is appended to ``buffer_tensor_list[i]``.
    """
    prefix = f"model.layers.{self.num_dense_layers}.mlp.experts."
    for weight_name in self.expert_weight_names:
        weight = self.param_dict[prefix + weight_name].data
        scratch = torch.empty_like(weight[0:num_buffer_tensor])
        for slot in range(num_buffer_tensor):
            self.buffer_tensor_list[slot].append(scratch[slot])
def init_expert_param_per_layer(self):
    """Index every local expert's weight tensors by MoE layer.

    After the call ``self.expert_param_per_layer[layer_idx][local_expert_id]``
    is a list holding, for each name in ``self.expert_weight_names`` (same
    order), that expert's slice of the corresponding parameter. The number
    of local experts is taken from dimension 0 of the first weight of layer
    ``self.num_dense_layers``.
    """
    first_key = (f"model.layers.{self.num_dense_layers}"
                 f".mlp.experts.{self.expert_weight_names[0]}")
    num_local_expert = self.param_dict[first_key].data.shape[0]

    for moe_layer_id in range(self.num_moe_layers):
        layer_idx = self.num_dense_layers + moe_layer_id
        prefix = f"model.layers.{layer_idx}.mlp.experts."
        self.expert_param_per_layer[layer_idx] = [
            [
                self.param_dict[prefix + weight_name].data[local_expert_id]
                for weight_name in self.expert_weight_names
            ]
            for local_expert_id in range(num_local_expert)
        ]
def get_rank_expert_workload(self) -> torch.Tensor:
    """Fetch the model's MoE load statistics, cache them and return them.

    The value returned by ``self.model.get_all_moe_loads()`` is stored in
    ``self.moe_load`` and handed back to the caller.
    """
    moe_load = self.model.get_all_moe_loads()
    self.moe_load = moe_load
    return moe_load
def get_init_expert_map(self, num_moe_layers):
    """Collect the initial expert map of every rank and cache this rank's rows.

    The local map returned by ``self.model.get_all_expert_map`` is gathered
    across ranks and permuted so that dimension 0 indexes the layer and
    dimension 1 the rank. When ``torch.distributed`` is not initialised the
    local map is treated as the only rank (previously this path failed with
    an unbound ``gathered``).

    Args:
        num_moe_layers: number of MoE layers to cache for this rank.

    Returns:
        CPU tensor of shape ``[layers, ranks, experts]``.

    Side effects:
        ``self.expert_map_per_layer_cpu[self.num_dense_layers + i]`` is set
        to this rank's row of layer ``i`` for ``i < num_moe_layers``.
    """
    expert_map = self.model.get_all_expert_map(num_moe_layers)

    if dist.is_initialized():
        world_size = dist.get_world_size()
        gathered = torch.empty(
            (world_size, *expert_map.shape),  # [W, L, E]
            dtype=expert_map.dtype,
            device=expert_map.device)
        dist.all_gather_into_tensor(gathered, expert_map)
    else:
        # Single-process run: behave like a world of size one.
        gathered = expert_map.unsqueeze(0)

    # [W, L, E] -> [L, W, E]; keep a host copy to avoid device syncs later.
    all_expert_maps = gathered.permute(1, 0, 2).cpu()

    for layer_idx in range(num_moe_layers):
        self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \
            all_expert_maps[layer_idx][self.rank_id]

    return all_expert_maps
def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
    """Build the initial all-rank expert map from a file, with a fallback.

    The file is parsed with ``self._expert_file_to_tensor`` and converted
    with ``self.local2global``. If that raises ``TypeError``,
    ``FileNotFoundError`` or ``OSError`` the map from
    ``self.determine_expert_map_all()`` is used instead.

    This rank's row of every layer is cached in
    ``self.expert_map_per_layer_cpu``: keyed by the plain layer index for
    ``qwen3_moe`` models and by ``layer_idx + self.num_dense_layers``
    otherwise.

    Returns:
        The all-rank expert map that was used.
    """
    try:
        expert_map_tensor, _layers_num, _ranks_num = \
            self._expert_file_to_tensor(expert_map_path)
        expert_map_all = self.local2global(expert_map_tensor)
    except (TypeError, FileNotFoundError, OSError):
        expert_map_all = self.determine_expert_map_all()

    for layer_idx in range(num_moe_layers):
        is_qwen3_moe = self.model.config.model_type == "qwen3_moe"
        key = layer_idx if is_qwen3_moe else layer_idx + self.num_dense_layers
        self.expert_map_per_layer_cpu[key] = \
            expert_map_all[layer_idx][self.rank_id]

    return expert_map_all
def _expert_file_to_tensor(self, expert_map_path: str):
with open(expert_map_path, "r") as f:
data = json.load(f)
layers_num = data["moe_layer_count"]
gpus_num = data["layer_list"][0]["device_count"]
tensor_data = []
for layer in data["layer_list"]:
device_data = []
for device in layer["device_list"]:
device_data.append(device["device_expert"])
tensor_data.append(device_data)
expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
return expert_map_tensor, layers_num, gpus_num
logger.error(f"failed to read expert_map_path: {expert_map_path}")
def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
if self.rank_id == 0:
num_local_experts = expert_maps.max() + 1
expert_maps_local = self.global2local(expert_maps,
num_local_experts)
expert_maps_list = expert_maps_local.tolist()
record: dict[str, Any] = {
"moe_layer_count": len(expert_maps_list),
"layer_list": []
}
for layer_idx, layer_data in enumerate(expert_maps_list):
layer_record: dict[str, Any] = {
"layer_id": layer_idx,
"device_count": len(layer_data),
"device_list": []
}
for device_idx, experts in enumerate(layer_data):
device_record = {
"device_id": device_idx,
"device_expert": experts
}
layer_record["device_list"].append(device_record)
record["layer_list"].append(layer_record)
with open(expert_map_record_path, "w") as f:
json.dump(record, f, indent=4)
def do_update_expert_map(self, layer_id, updated_expert_map):
    """Publish a new expert map for ``layer_id``.

    ``self.expert_map_per_layer`` holds references to the model's expert
    maps, so the new values are copied *into* the existing tensor; simply
    rebinding the dict entry (the previous behaviour) left the referenced
    tensor untouched. If no tensor is registered for the layer, a clone is
    stored instead. The CPU-side cache always receives a fresh clone.

    NOTE(review): assumes ``updated_expert_map`` has the same shape as the
    registered map - confirm against ``model.get_expert_map``.
    """
    device_map = self.expert_map_per_layer.get(layer_id)
    if device_map is None:
        self.expert_map_per_layer[layer_id] = updated_expert_map.clone()
    else:
        # In-place copy, consistent with do_update_log2phy_map.
        device_map.copy_(updated_expert_map)
    self.expert_map_per_layer_cpu[layer_id] = updated_expert_map.clone()
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
                            buffer_tensor_id):
    """Overwrite one local expert's weights with a received buffer.

    Every tensor of expert ``local_expert_to_replace`` in layer
    ``layer_id`` is paired with the tensor at the same position in buffer
    slot ``buffer_tensor_id`` and overwritten in place.

    The previous implementation assigned ``buffer_tensor.clone()`` to the
    loop variable, which only rebound a local name and never changed the
    expert weights.
    """
    expert_tensors = self.expert_param_per_layer[layer_id][
        local_expert_to_replace]
    buffer_tensors = self.buffer_tensor_list[buffer_tensor_id]
    for expert_tensor, buffer_tensor in zip(expert_tensors, buffer_tensors):
        # copy_ writes through the view into the underlying parameter.
        expert_tensor.copy_(buffer_tensor)
        logger.debug(f"Expert tensor shape is :{expert_tensor.shape}")
def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
    """Copy a new logical-to-physical map into the layer's tensor in place.

    Layers whose entry in ``self.log2phy_map_per_layer`` is ``None`` are
    left untouched.
    """
    target = self.log2phy_map_per_layer[layer_id]
    if target is None:
        return
    target.copy_(updated_log2phy_map)
def global2local(self, placement: torch.Tensor,
                 E_local: int) -> torch.Tensor:
    """Convert a global expert map into a per-rank slot table.

    Args:
        placement: ``[layers, ranks, global_experts]`` tensor whose value is
            the local slot of the expert on that rank, or a negative value
            when the rank does not hold it.
        E_local: number of slots per rank in the result.

    Returns:
        ``[layers, ranks, E_local]`` ``long`` tensor holding the global
        expert id stored in each slot, ``-1`` for unused slots.
    """
    num_layers, num_ranks, _ = placement.shape
    slot_table = torch.full((num_layers, num_ranks, E_local),
                            fill_value=-1,
                            dtype=torch.long,
                            device=placement.device)

    layer_ids, rank_ids, expert_ids = torch.nonzero(placement >= 0,
                                                    as_tuple=True)
    slots = placement[layer_ids, rank_ids, expert_ids]
    slot_table[layer_ids, rank_ids, slots] = expert_ids
    return slot_table
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
    """Convert a per-rank slot table into a global expert map.

    Args:
        placement_local: ``[layers, ranks, slots]`` tensor holding the
            global expert id stored in each slot (negative for unused).

    Returns:
        ``[layers, ranks, max_id + 1]`` ``long`` tensor whose value is the
        slot of each expert on each rank, ``-1`` where the rank does not
        hold the expert. When no slot is used the last dimension is 0.
    """
    num_layers, num_ranks, _ = placement_local.shape
    device = placement_local.device

    largest_id = torch.max(placement_local)
    if largest_id < 0:
        return torch.empty((num_layers, num_ranks, 0),
                           dtype=torch.long,
                           device=device)

    num_global = (largest_id + 1).item()
    global_map = torch.full((num_layers, num_ranks, num_global),
                            fill_value=-1,
                            dtype=torch.long,
                            device=device)

    layer_ids, rank_ids, slots = torch.nonzero(placement_local >= 0,
                                               as_tuple=True)
    expert_ids = placement_local[layer_ids, rank_ids, slots]
    global_map[layer_ids, rank_ids, expert_ids] = slots
    return global_map
def determine_expert_map_all(self):
    """Build the default expert map of every rank for every MoE layer.

    Experts are split into contiguous ranges of
    ``global_expert_num // world_size`` per rank, the last rank taking the
    remainder. Ranks with index below ``self.init_redundancy_expert`` get
    one extra slot: their range is extended by one expert at the end, or
    at the front when it already reaches the last expert.

    Returns:
        ``int32`` tensor ``[num_moe_layers, world_size, global_expert_num]``
        holding the local slot of each expert on each rank (``-1`` when the
        rank does not hold it).
    """
    if self.world_size == 1:
        all_ids = torch.arange(self.global_expert_num, dtype=torch.int32)
        return all_ids.view(1, 1, -1).expand(self.num_moe_layers, 1, -1)

    experts_per_rank = self.global_expert_num // self.world_size
    table = torch.full(
        (self.num_moe_layers, self.world_size, self.global_expert_num),
        -1,
        dtype=torch.int32)

    last_rank = self.world_size - 1
    for r in range(self.world_size):
        start = r * experts_per_rank
        end = (r + 1) * experts_per_rank if r < last_rank \
            else self.global_expert_num
        local_count = end - start

        if r < self.init_redundancy_expert:
            local_count += 1
            if end < self.global_expert_num:
                end += 1
            else:
                start -= 1

        slots = torch.arange(local_count, dtype=torch.int32)
        table[:, r, start:end] = slots.unsqueeze(0).expand(
            self.num_moe_layers, -1)

    return table

View File

View File

@@ -0,0 +1,137 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from enum import Enum
import torch.distributed as dist
from vllm.logger import logger
class ExpertWeightUpdateState(Enum):
    """Phases of a device-to-device expert weight update."""

    # Waiting for an updated expert_map from EplbWorker.
    WAITING = 0
    # Ready for d2d expert weights updating.
    READY = 1
    # d2d finished and waiting for updating expert_map into model.
    TRANSFERRING = 2
class D2DExpertWeightLoader:
    """Drives device-to-device expert weight updates for one layer at a time.

    The loader moves through ``ExpertWeightUpdateState``: ``WAITING`` ->
    ``READY`` (P2P operations prepared) -> ``TRANSFERRING`` (operations
    launched) -> back to ``WAITING`` once maps and weights are applied.
    """

    def __init__(self):
        self.comm_op_list = None
        self.updated_expert_map = None
        self.updated_log2phy_map = None
        self.layer_id = -1  # layer id to be updated
        self.state = ExpertWeightUpdateState.WAITING
        self.recv_expert_list = []
        self.mock_flag = True

    def set_adator(self, eplb_adaptor):
        """Attach the adaptor that exposes expert maps, weights and buffers."""
        self.eplb_adaptor = eplb_adaptor

    def generate_expert_d2d_transfer_task(self, expert_send_info,
                                          expert_recv_info, updated_expert_map,
                                          layer_id):
        """Prepare the send/receive operations for one layer.

        Args:
            expert_send_info: iterable of ``(dst_rank, global_expert_id)``.
            expert_recv_info: iterable of ``(src_rank, global_expert_id)``.
            updated_expert_map: this rank's new expert map for the layer.
            layer_id: layer the update applies to.

        Nothing is prepared when a previous update is still in progress or
        when both info collections are empty.
        """
        # A new task is only accepted once the previous one is fully applied.
        if self.state != ExpertWeightUpdateState.WAITING:
            logger.error(
                "current d2d weight update tasks are on-going, cannot accept new weight update task"
            )
            return

        # This rank neither sends nor receives anything for this layer.
        if not expert_send_info and not expert_recv_info:
            return

        self.updated_expert_map = updated_expert_map
        self.layer_id = layer_id
        self.comm_op_list = []

        for dst_rank, global_expert_id_to_send in expert_send_info:
            # Locate the expert's current local slot via the CPU-side map.
            local_expert_id = self.eplb_adaptor.expert_map_per_layer_cpu[
                layer_id][global_expert_id_to_send].item()
            expert_tensors = self.eplb_adaptor.expert_param_per_layer[
                layer_id][local_expert_id]
            for src_tensor in expert_tensors:
                self.comm_op_list.append(
                    dist.P2POp(dist.isend, src_tensor, dst_rank))

        # Each incoming expert lands in its own buffer slot.
        for buffer_tensor_id, recv_info in enumerate(expert_recv_info):
            recv_rank, global_expert_id_to_recv = recv_info
            for buffer_tensor in self.eplb_adaptor.buffer_tensor_list[
                    buffer_tensor_id]:
                self.comm_op_list.append(
                    dist.P2POp(dist.irecv, buffer_tensor, recv_rank))
            local_expert_to_replace = self.updated_expert_map[
                global_expert_id_to_recv].item()
            self.recv_expert_list.append(
                (local_expert_to_replace, buffer_tensor_id))

        self.state = ExpertWeightUpdateState.READY

    def set_log2phy_map(self, log2phy_map):
        """Remember the log2phy map to apply with the next update."""
        self.updated_log2phy_map = log2phy_map

    def asyn_expert_weight_transfer(self, reqs):
        """Launch the prepared P2P operations and append their handles to ``reqs``."""
        # Operations can only be launched once they have been prepared.
        if self.state != ExpertWeightUpdateState.READY:
            return

        if self.comm_op_list:
            reqs.extend(dist.batch_isend_irecv(self.comm_op_list))

        self.state = ExpertWeightUpdateState.TRANSFERRING

    def update_expert_map_and_weight(self, reqs):
        """Wait for the transfers, then apply maps and received weights."""
        # Maps and weights may only change after the transfers were launched.
        if self.state != ExpertWeightUpdateState.TRANSFERRING:
            return

        # Block until every outstanding send/recv has completed.
        for req in reqs:
            req.wait()
        self.comm_op_list = None

        adaptor = self.eplb_adaptor
        adaptor.do_update_expert_map(self.layer_id, self.updated_expert_map)
        adaptor.do_update_log2phy_map(self.layer_id,
                                      self.updated_log2phy_map)

        # Move every received expert from its buffer slot into place.
        for local_expert_to_replace, buffer_tensor_id in self.recv_expert_list:
            self.eplb_adaptor.do_update_expert_weight(self.layer_id,
                                                      local_expert_to_replace,
                                                      buffer_tensor_id)

        logger.info(
            f"[EPLB] finished update expert weight for layer: {self.layer_id}")

        # Reset for the next task.
        self.recv_expert_list = []
        self.updated_expert_map = None
        self.layer_id = -1
        self.state = ExpertWeightUpdateState.WAITING

    def load_impl(self, old_expert_table, new_expert_table):
        """Not implemented."""
        raise NotImplementedError

View File

@@ -0,0 +1,135 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# TODO: Remove these EPLB utils once https://github.com/vllm-project/vllm/issues/22246 is merged in vLLM.
import random
import torch
from vllm.logger import logger
def determine_default_expert_map(global_expert_num, world_size, rank_id,
                                 global_redundant_expert_num):
    """Return the default expert map of a single rank.

    Experts are split into contiguous ranges of
    ``global_expert_num // world_size``; the last rank takes the remainder.
    When ``global_redundant_expert_num`` is an ``int`` greater than
    ``rank_id`` the rank gets one extra slot: its range grows by one expert
    at the end, or at the front when it already reaches the last expert.

    Returns:
        ``(local_count, expert_map)`` where ``expert_map`` is an ``int32``
        tensor of length ``global_expert_num`` holding the local slot of
        each expert (``-1`` when not held by this rank). With
        ``world_size == 1`` every expert is local.
    """
    if world_size == 1:
        return (global_expert_num,
                torch.arange(global_expert_num, dtype=torch.int32))

    experts_per_rank = global_expert_num // world_size
    expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)

    start = rank_id * experts_per_rank
    if rank_id < world_size - 1:
        end = (rank_id + 1) * experts_per_rank
        local_count = experts_per_rank
    else:
        end = global_expert_num
        local_count = global_expert_num - rank_id * experts_per_rank

    gets_redundant_slot = isinstance(
        global_redundant_expert_num,
        int) and rank_id < global_redundant_expert_num
    if gets_redundant_slot:
        local_count += 1
        if end < global_expert_num:
            end += 1
        else:
            start -= 1

    if isinstance(local_count, int):
        expert_map[start:end] = torch.arange(local_count, dtype=torch.int32)
    return (local_count, expert_map)
def generate_log2phy_map(expert_map):
    """Build the logical-to-physical expert map for every rank.

    Args:
        expert_map: ``[num_ranks, num_global_expert]`` tensor holding the
            local slot of each expert on each rank, ``-1`` when the rank
            does not hold it.

    Returns:
        A new tensor of the same shape. Held experts are mapped to
        ``slot + rank * (expert_map.max() + 1)``. For ranks that do not
        hold an expert the entry is:

        * ``0`` for every rank if no rank holds the expert,
        * the single holder's physical id if exactly one rank holds it,
        * a randomly chosen holder's physical id otherwise.

    ``expert_map`` itself is not modified.
    """
    num_local_experts = expert_map.max() + 1
    log2phy_map = expert_map.clone()
    num_ranks, num_global_expert = log2phy_map.shape

    # Offset every held slot by the first physical id of its rank.
    row_indices = torch.arange(num_ranks).view(-1, 1).expand(
        num_ranks, num_global_expert) * num_local_experts
    placed = log2phy_map != -1
    log2phy_map[placed] += row_indices[placed]

    for idx in range(num_global_expert):
        positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0]
        negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0]
        num_rank_holding_expert = positive_rank_idx.size(0)

        if num_rank_holding_expert == 0:
            log2phy_map[:, idx] = torch.full((num_ranks, ),
                                             0,
                                             dtype=log2phy_map.dtype)
        # ``elif``: an unplaced expert used to fall through to the random
        # branch, fail on the empty candidate list and log a bogus error.
        elif num_rank_holding_expert == 1:
            log2phy_map[negative_rank_idx, idx] = torch.full(
                (num_ranks - 1, ),
                log2phy_map[positive_rank_idx, idx].item(),
                dtype=log2phy_map.dtype)
        else:
            try:
                candidates = log2phy_map[positive_rank_idx, idx].tolist()
                random_list = [
                    random.choice(candidates)
                    for _ in range(num_ranks - num_rank_holding_expert)
                ]
                log2phy_map[negative_rank_idx,
                            idx] = torch.tensor(random_list,
                                                dtype=log2phy_map.dtype)
            except Exception as e:
                # Best effort: keep the -1 entries and report the failure.
                logger.error(f"Fail to get log2phy_map: {str(e)}")

    return log2phy_map
def determine_default_log2phy_map(global_expert_num, world_size, rank_id,
                                  global_redundant_expert_num):
    """Return the default logical-to-physical expert map of one rank.

    The default expert map of *every* rank is built first (contiguous
    ranges of ``global_expert_num // world_size``, remainder on the last
    rank, one extra slot for each rank whose index is below
    ``global_redundant_expert_num``), converted with
    ``generate_log2phy_map`` and then row ``rank_id`` is returned.

    The redundancy test is applied to the rank being laid out (``r``).
    It previously used the caller's ``rank_id``, which gave every row the
    caller's redundancy decision and so produced a layout that differed
    between callers and from ``determine_default_expert_map``.
    """
    if world_size == 1:
        local_ids = torch.arange(global_expert_num, dtype=torch.int32)
        expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
        log2phy_map_all = generate_log2phy_map(expert_map_all)
        return log2phy_map_all[rank_id]

    local_num_experts = global_expert_num // world_size
    expert_map_all = torch.full((world_size, global_expert_num),
                                -1,
                                dtype=torch.int32)

    for r in range(world_size):
        start = r * local_num_experts
        if r < world_size - 1:
            end = (r + 1) * local_num_experts
            local_count = local_num_experts
        else:
            end = global_expert_num
            local_count = global_expert_num - r * local_num_experts

        # Per-rank decision: only ranks below the redundancy count grow.
        if isinstance(global_redundant_expert_num,
                      int) and r < global_redundant_expert_num:
            local_count += 1
            if end < global_expert_num:
                end += 1
            else:
                start -= 1

        if isinstance(local_count, int):
            local_ids = torch.arange(local_count, dtype=torch.int32)
            expert_map_all[r, start:end] = local_ids

    log2phy_map_all = generate_log2phy_map(expert_map_all)
    return log2phy_map_all[rank_id]

View File

@@ -0,0 +1,436 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from multiprocessing import Process, Queue
from typing import Any
import networkx as nx # type: ignore
import numpy as np
import torch
import torch.distributed as dist
from vllm.logger import logger
from vllm_ascend.eplb.core.eplb_utils import generate_log2phy_map
from vllm_ascend.eplb.core.policy.policy_factory import (DynamicConfig,
PolicyFactory)
class EplbWorker:
    """Computes new expert placements and the transfers needed to reach them.

    Expert maps and MoE load statistics are exchanged with the main
    process through ``shared_dict`` (keys ``"expert_maps"`` and
    ``"moe_load"``).
    """

    def __init__(self, shared_dict, policy_type, enable_d2d: bool = True):
        self.policy_type = policy_type
        self.policy = PolicyFactory.generate_policy(policy_type,
                                                    DynamicConfig())
        self.shared_dict = shared_dict
        self.old_expert_maps = None
        self.enable_d2d = enable_d2d
        self.rank_id = dist.get_rank()

    def do_update(self):
        """Run one rebalancing round.

        Returns:
            The packed update info for this rank (see ``pack_update_info``),
            or ``None`` when no load information is available yet.

        Raises:
            ValueError: if the initial expert maps are missing from
                ``shared_dict``.
        """
        torch.set_num_threads(1)

        # Lazily fetch the initial expert maps on the first round.
        if self.old_expert_maps is None:
            self.old_expert_maps = self.get_init_expert_maps()
            if self.old_expert_maps is not None:
                self.num_local_experts = self.old_expert_maps.max() + 1
            else:
                raise ValueError("Failed to get expert_maps from shared_dict.")

        # Get MOE load information
        load_info = self.fetch_and_sum_load_info()
        if load_info is None:
            return

        # Get the updated expert table based on the workload information
        old_placement = self.global2local(self.old_expert_maps,
                                          self.num_local_experts)
        _, _, new_placement = self.calculate_rebalance_experts(
            load_info, old_placement)

        if not torch.is_tensor(new_placement):
            new_placement = torch.tensor(new_placement)
        # Invalid layers are reverted to the old placement in place.
        self.check_expert_placement(old_placement, new_placement)
        new_expert_maps = self.local2global(new_placement)
        self.update_expert_map(new_expert_maps)

        # The generator binds the *old* maps here, before they are replaced.
        update_info = self.compose_expert_update_info_greedy(
            new_expert_maps, self.old_expert_maps)
        self.old_expert_maps = new_expert_maps
        logger.info("EPLB Process compute complete")

        packed_update_info = self.pack_update_info(update_info)
        return packed_update_info

    def check_expert_placement(self, old_placement, new_placement):
        """Validate ``new_placement`` layer by layer, reverting invalid layers.

        A layer of ``new_placement`` is overwritten with the matching layer
        of ``old_placement`` when it has fewer distinct entries than the old
        one, when a rank holds the same expert twice, or when an expert that
        stays on a rank changes its slot.
        """
        num_layers = old_placement.shape[0]
        num_ranks = old_placement.shape[1]

        for layer_id in range(num_layers):
            # check if any logical expert is not placed on any rank
            if torch.unique(new_placement[layer_id]).numel() < torch.unique(
                    old_placement[layer_id]).numel():
                logger.error(
                    f"There exists expert not placed on any rank in layer {layer_id}"
                )
                new_placement[layer_id] = old_placement[layer_id]
                continue

            for rank_id in range(num_ranks):
                new_placement_check = new_placement[layer_id][rank_id]
                old_placement_check = old_placement[layer_id][rank_id]

                # check if same logical experts are placed on the same NPU
                if new_placement_check.numel() != torch.unique(
                        new_placement_check).numel():
                    logger.error(
                        f"Replicated experts are placed on the same NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
                    )
                    new_placement[layer_id] = old_placement[layer_id]
                    break

                # check if there is any experts movement inside one NPU
                expert_not_move = torch.isin(new_placement_check,
                                             old_placement_check)
                if not torch.equal(new_placement_check[expert_not_move],
                                   old_placement_check[expert_not_move]):
                    logger.error(
                        f"There exists expert movement inside NPU, expert placement on layer {layer_id}, rank {rank_id} is invalid"
                    )
                    new_placement[layer_id] = old_placement[layer_id]
                    break

    def compose_expert_update_info_bipartite(self, updated_expert_maps_org,
                                             current_expert_maps_org):
        """Yield per-layer transfer plans computed by bipartite matching.

        For every layer one tuple ``(send_info, recv_info,
        updated_expert_map, layer_id)`` is yielded, where ``send_info`` maps
        a source rank to ``(dst_rank, expert_id)`` pairs and ``recv_info``
        maps a destination rank to ``(src_rank, expert_id)`` pairs. Layers
        whose map did not change yield empty dicts exactly once.
        """
        # transform torch tensors to numpy arrays
        updated_expert_maps = np.array(updated_expert_maps_org.clone())
        current_expert_maps = np.array(current_expert_maps_org.clone())

        num_layers = current_expert_maps.shape[0]

        for layer_id in range(num_layers):
            updated_expert_maps_this_layer = updated_expert_maps[layer_id]
            current_expert_maps_this_layer = current_expert_maps[layer_id]
            updated_expert_maps_this_layer_org = updated_expert_maps_org[
                layer_id]

            expert_send_info_this_layer: dict[Any, Any] = {}
            expert_recv_info_this_layer: dict[Any, Any] = {}

            # Guard Clause: if there is no expert weight update, avoid subsequent processing
            if (np.equal(updated_expert_maps_this_layer,
                         current_expert_maps_this_layer)).all():
                yield (expert_send_info_this_layer,
                       expert_recv_info_this_layer,
                       updated_expert_maps_this_layer_org, layer_id)
                # Without this the layer was yielded a second time below.
                continue

            # Parse expert_ids each rank needs to receive from other ranks
            dst_rank_indices, experts_to_recv = np.where(
                (current_expert_maps_this_layer == -1)
                & (updated_expert_maps_this_layer != -1))

            # record src ranks for potential transfer
            src_ranks_set = dict()
            for idx in range(len(dst_rank_indices)):
                expert_id = experts_to_recv[idx].item()
                if expert_id not in src_ranks_set:
                    src_ranks_set[expert_id] = np.where(
                        current_expert_maps_this_layer[:, expert_id] != -1)[0]

            # loop until all experts are scheduled
            while len(dst_rank_indices) > 0:
                # construct bipartite graph; destination ranks are stored as
                # strings so they never collide with source-rank nodes
                graph_expert_update: nx.Graph = nx.Graph()
                for idx in range(len(dst_rank_indices)):
                    dst_rank_id = dst_rank_indices[idx].item()
                    expert_id = experts_to_recv[idx].item()
                    # add src ranks
                    src_rank_ids = src_ranks_set[expert_id]
                    graph_expert_update.add_nodes_from(src_rank_ids,
                                                       bipartite=0)
                    # add dest rank
                    graph_expert_update.add_node(str(dst_rank_id), bipartite=1)
                    # add edges
                    for src_rank_id in src_rank_ids:
                        graph_expert_update.add_edge(src_rank_id,
                                                     str(dst_rank_id))

                # graph may not be connected
                connected_components = list(
                    nx.connected_components(graph_expert_update))
                all_matches = {}
                # matching in this loop
                for component in connected_components:
                    subgraph = graph_expert_update.subgraph(component)
                    component_matching = nx.bipartite.maximum_matching(
                        subgraph)
                    all_matches.update(component_matching)

                for src_rank, dst_rank in all_matches.items():
                    dst_rank = int(dst_rank)
                    assert src_rank != dst_rank
                    # Only handle pairs keyed by their source-side node.
                    if graph_expert_update.nodes[src_rank]['bipartite'] == 0:
                        # currently not scheduled experts in rank dst_rank
                        experts_v = experts_to_recv[np.where(
                            dst_rank_indices == dst_rank)]
                        # src: src_rank, dest: dst_rank, expert: expert_id
                        expert_id = np.intersect1d(
                            experts_v,
                            np.where(current_expert_maps_this_layer[src_rank]
                                     != -1))[0]

                        # record send/rcv pairs
                        if src_rank not in expert_send_info_this_layer:
                            expert_send_info_this_layer[src_rank] = []
                        if dst_rank not in expert_recv_info_this_layer:
                            expert_recv_info_this_layer[dst_rank] = []
                        expert_send_info_this_layer[src_rank].append(
                            (dst_rank, expert_id))
                        expert_recv_info_this_layer[dst_rank].append(
                            (src_rank, expert_id))

                        remove_index = np.where(
                            np.logical_and(dst_rank_indices == dst_rank,
                                           experts_to_recv == expert_id))

                        # update
                        dst_rank_indices = np.delete(dst_rank_indices,
                                                     remove_index)
                        experts_to_recv = np.delete(experts_to_recv,
                                                    remove_index)

            yield (expert_send_info_this_layer, expert_recv_info_this_layer,
                   updated_expert_maps_this_layer_org, layer_id)

    # TODO: Here only expert weight exchange is considered, need to be extended to cover other weight update cases
    def compose_expert_update_info_greedy(self, updated_expert_maps,
                                          current_expert_maps):
        """Yield per-layer transfer plans using a greedy source choice.

        For every layer one tuple ``(send_info, recv_info,
        updated_expert_map, layer_id)`` is yielded. An expert a rank must
        receive is taken from a rank that is giving it up if there is one,
        otherwise from the first rank currently holding it. Layers whose
        map did not change yield empty dicts exactly once.
        """
        num_layers = current_expert_maps.shape[0]

        for layer_id in range(num_layers):
            updated_expert_maps_this_layer = updated_expert_maps[layer_id]
            current_expert_maps_this_layer = current_expert_maps[layer_id]

            expert_send_info_this_layer: dict[Any, Any] = {}
            expert_recv_info_this_layer: dict[Any, Any] = {}

            # Guard Clause: if there is no expert weight update, avoid subsequent processing
            if torch.equal(updated_expert_maps_this_layer,
                           current_expert_maps_this_layer):
                yield (expert_send_info_this_layer,
                       expert_recv_info_this_layer,
                       updated_expert_maps_this_layer, layer_id)
                # Without this the layer was yielded a second time below.
                continue

            # Parse expert_ids each rank needs to receive from other ranks
            dst_rank_indices, experts_to_recv = torch.where(
                (current_expert_maps_this_layer == -1)
                & (updated_expert_maps_this_layer != -1))

            # Parse expert_ids each rank needs to send to other ranks
            src_rank_indices, experts_to_send = torch.where(
                (current_expert_maps_this_layer != -1)
                & (updated_expert_maps_this_layer == -1))

            for idx in range(len(dst_rank_indices)):
                dst_rank_id = dst_rank_indices[idx].item()
                expert_id = experts_to_recv[idx].item()
                if dst_rank_id not in expert_recv_info_this_layer:
                    expert_recv_info_this_layer[dst_rank_id] = []

                if not torch.isin(torch.tensor(expert_id),
                                  experts_to_send).any():
                    # if expert_id are not sent out from any npu, it will be copied from one npu holding this expert
                    candidate_src_rank_indices = torch.where(
                        current_expert_maps_this_layer[:, expert_id] != -1)[0]
                else:
                    candidate_src_rank_indices = src_rank_indices[
                        experts_to_send == expert_id]

                # TODO: improve selection criterion of npu sending expert_id considering such as intra-node or inter-node...
                src_rank_id = candidate_src_rank_indices[0].item()
                if src_rank_id not in expert_send_info_this_layer:
                    expert_send_info_this_layer[src_rank_id] = []

                expert_send_info_this_layer[src_rank_id].append(
                    (dst_rank_id, expert_id))
                expert_recv_info_this_layer[dst_rank_id].append(
                    (src_rank_id, expert_id))

            yield (expert_send_info_this_layer, expert_recv_info_this_layer,
                   updated_expert_maps_this_layer, layer_id)

    def calculate_rebalance_experts(self, load_info, old_placement):
        """
        Compute `new_map` by calling the `rebalance_experts` method of the policy instance.

        Returns ``(changed, priority, new_map)``; ``(False, None, None)``
        when no expert maps have been loaded yet.
        """
        if self.old_expert_maps is None:
            return False, None, None

        changed, priority, new_map = self.policy.rebalance_experts(
            old_placement, load_info)
        return changed, priority, new_map

    def get_init_expert_maps(self):
        """
        Read the initial expert_map from shared_dict.
        """
        return self.shared_dict.get("expert_maps", None)

    def fetch_and_sum_load_info(self):
        """
        Each time the subprocess is awakened, read the latest moe_load
        (shape: [num_moe_layers, num_experts_per_layer]) from shared_dict.
        """
        return self.shared_dict.get("moe_load", None)

    def update_expert_map(self, expert_maps):
        """Publish the new expert maps through ``shared_dict``."""
        self.shared_dict["expert_maps"] = expert_maps

    def global2local(self, placement: torch.Tensor,
                     E_local: int) -> torch.Tensor:
        """Convert ``[L, G, E_global]`` slot maps to ``[L, G, E_local]`` id tables.

        Entry ``[l, g, slot]`` of the result is the global expert id stored
        in that slot, or ``-1`` for an unused slot.
        """
        L, G, _ = placement.shape
        device = placement.device

        pt_local = torch.full((L, G, E_local),
                              fill_value=-1,
                              dtype=torch.long,
                              device=device)

        valid = placement >= 0
        l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)

        slot_idx = placement[l_idx, g_idx, k_idx]

        pt_local[l_idx, g_idx, slot_idx] = k_idx

        return pt_local

    def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
        """Convert ``[L, G, E_local]`` id tables to ``[L, G, max_id + 1]`` slot maps.

        Entry ``[l, g, expert]`` of the result is the slot holding that
        expert on the rank, or ``-1``. The last dimension is 0 when no slot
        is used.
        """
        L, G, E_local = placement_local.shape
        device = placement_local.device

        max_id = torch.max(placement_local)
        E_global = (max_id + 1).item() if max_id >= 0 else 0

        if E_global == 0:
            return torch.empty((L, G, 0), dtype=torch.long, device=device)

        placement_global = torch.full((L, G, E_global),
                                      fill_value=-1,
                                      dtype=torch.long,
                                      device=device)

        valid = placement_local >= 0
        l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
        gid_idx = placement_local[l_idx, g_idx, slot_idx]

        placement_global[l_idx, g_idx, gid_idx] = slot_idx

        return placement_global

    def pack_update_info(self, update_info_generator):
        """
        Pack a list of update info tuples for efficient IPC.

        Each element of the returned list is ``(send_info, recv_info,
        expert_map, log2phy_map, layer_id)`` restricted to this rank, with
        the maps converted to plain Python lists.
        """
        send_all = []
        recv_all = []
        maps = []
        log2phy_all = []
        layer_ids = []

        for send_info, recv_info, new_expert_map, layer_id in update_info_generator:
            # Ranks without work for this layer get an empty list.
            send_all.append(send_info.get(self.rank_id, []))
            recv_all.append(recv_info.get(self.rank_id, []))

            maps.append(new_expert_map[self.rank_id].numpy().tolist())

            log2phy_map = generate_log2phy_map(new_expert_map)
            log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist())

            layer_ids.append(layer_id)

        return list(zip(send_all, recv_all, maps, log2phy_all, layer_ids))
class EplbProcess:
    """Runs an ``EplbWorker`` in a background process.

    ``planner_q`` wakes the worker up; ``block_update_q`` (capacity 1)
    carries the packed update info back to the main process.
    """

    def __init__(self,
                 shared_dict,
                 policy_type: int = 0,
                 enable_d2d: bool = True):
        """
        Args:
            shared_dict: Cross-process shared dict returned by Manager().dict()
            policy_type: Integer passed to PolicyFactory.generate_policy
            enable_d2d: Whether to enable D2D loading
        """
        self.shared_dict = shared_dict
        self.policy_type = policy_type
        self.enable_d2d = enable_d2d
        self.planner_q: Queue[Any] = Queue()
        self.block_update_q: Queue[Any] = Queue(maxsize=1)

        # Create EplbWorker instance
        self.worker = EplbWorker(self.shared_dict, self.policy_type,
                                 self.enable_d2d)

    def worker_process(self, planner_q, block_update_q):
        """
        Subprocess entry: loop waiting for planner_q to wake up, call
        do_update, then hand the result to the main process through
        block_update_q.

        The loop ends (after logging a warning) on the first exception.
        """
        while True:
            try:
                planner_q.get()

                packed_update_info = self.worker.do_update()

                # block_update_q is bounded (maxsize=1), so a blocking put
                # waits for the previous result to be consumed. This
                # replaces a loop that polled ``empty()`` and spun a CPU
                # core while waiting.
                block_update_q.put(packed_update_info)

            except Exception as e:
                logger.warning(f"[EPLB subprocess Exiting due to error: {e}",
                               exc_info=True)
                break

    def _launch_process(self):
        """
        Start ``worker_process`` in a daemon subprocess and return the
        ``Process`` object. The queues are available as ``self.planner_q``
        and ``self.block_update_q``.
        """
        proc = Process(target=self.worker_process,
                       args=(self.planner_q, self.block_update_q),
                       daemon=True)

        proc.start()
        return proc

View File

View File

@@ -0,0 +1,42 @@
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
# TODO: Remove this policy once https://github.com/vllm-project/vllm/pull/24069 is merged in vLLM.
from abc import abstractmethod
class DynamicConfig:
    """Default settings handed to EPLB placement policies."""

    placement_policy = None

    # Maximum number of experts that can be migrated per layer on a single host
    max_transferred_expert_per_layer = 100

    # Total number of dies across the entire cluster where experts are distributed
    ep_worldsize = 64

    # Number of dies on each host machine
    num_die_per_host = 8
class EplbPolicy:
    """Base class of expert-placement rebalancing policies."""

    def __init__(self, config: DynamicConfig):
        self.config = config

    @abstractmethod
    def rebalance_experts(self, current_expert_table, expert_workload):
        """
        Compute expert replication and placement from the observed workload,
        under the policy's constraints.

        INPUT:
        current_expert_table: [layerId, rankId, expert_num_i]
        expert_workload = expert_table[layer0][rankId][expert_num_i]

        RETURNED: (res, expert_table)
        res:
        1 -- table_changed
        0 -- not_changed

        expert_table: [layerId, rankId, expert_num_i]
        expert_num_i --- [0, MaxExpertPerRank]
        expertID = expert_table[layer0][rankId][expert_num_i]
        array_values:
        [0, 1, 2, 3, 248]
        [4, 5, 6, 7, 254]
        [8, 9, 10, 11, 71]
        ...
        [252, 253, 254, 255, 0]

        NOTE(review): EplbWorker.calculate_rebalance_experts unpacks three
        values (changed, priority, new_map) from this call - confirm which
        return shape concrete policies actually use.
        """

View File

@@ -0,0 +1,389 @@
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# TODO: Remove this policy once https://github.com/vllm-project/vllm/pull/24069 is merged in vLLM.
from collections import defaultdict
from typing import cast
import numpy as np
from .policy_abstract import DynamicConfig, EplbPolicy
class DynamicTable:
    """Holder for a workload table and the matching placement table.

    Both tables are 3D matrices indexed [layer, gpus, experts_per_gpu_per_layer],
    i.e. number of layers * number of GPUs * number of experts per GPU per
    layer. Positions whose expert is not available or not collected hold -1.
    """

    # workload_table:
    # element (i, j, k) is the workload (heat) of the k-th expert on the
    # j-th GPU in the i-th layer
    workload_table = None

    # placement_table:
    # element (i, j, k) is the physical expert ID of the k-th expert on the
    # j-th GPU in the i-th layer
    placement_table = None
class DynamicEplb(EplbPolicy):
def __init__(self, config: DynamicConfig):
    """Forward ``config`` to the ``EplbPolicy`` base initialiser."""
    super().__init__(config)
@staticmethod
def add_redundant(current_expert_table, expert_workload,
num_original_expert):
layer_num, npu_num, experts_per_npu = expert_workload.shape
workload_new = np.zeros((layer_num, num_original_expert))
for layer_idx in range(layer_num):
workload_dict: dict[int, int] = defaultdict(int)
placement_layer = current_expert_table[layer_idx].copy()
workload_layer = expert_workload[layer_idx].copy()
for npu_idx in range(npu_num):
for expert_idx in range(experts_per_npu):
workload_dict[placement_layer[npu_idx][
expert_idx]] += workload_layer[npu_idx][expert_idx]
for expert_idx in range(num_original_expert):
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
return workload_new
@staticmethod
# Split hot (high-load) experts into redundant experts
def original_compute_balanced_pack_redundancy(origin_weights, card_num,
num_redundancy_expert):
# Step 1: Sort the items by weight in descending order (we are sorting by weight now)
# Sort based on the second element (the second value of each tuple)
route_expert_num = len(origin_weights)
route_expert_redundancy: list[list[int]] = [
[] for _ in range(route_expert_num)
]
for i in range(num_redundancy_expert):
sorted_indices = np.argsort([t[1] for t in origin_weights],
kind='stable')[::-1]
weights = [origin_weights[idx] for idx in sorted_indices]
tmp_raw_weight = weights[0][1] * (
len(route_expert_redundancy[weights[0][0]]) + 1)
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
avg_weight = tmp_raw_weight / (
len(route_expert_redundancy[weights[0][0]]) + 1)
weights[0] = (weights[0][0], avg_weight)
origin_weights = weights
# Step 2: Calculate the number of items per box
expert_num = route_expert_num + num_redundancy_expert
items_per_box = expert_num // card_num # Number of items per box
remaining_items = expert_num % card_num # Number of items per box
# Step 3: Initialize card_num boxes with empty lists to store item IDs
boxes: list[list[int]] = [[] for _ in range(card_num)]
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
box_weights = [0] * card_num # To store the total weight of each box
box_counts = [0] * card_num # To store the number of items in each box
index = 0
for i in range(route_expert_num):
redundancy_num = len(route_expert_redundancy[i])
for _ in range(redundancy_num):
cur_weight = 0
for item, weight in origin_weights:
if item == i:
cur_weight = weight
boxes[index].append(i)
boxes_weights[index].append(cur_weight)
box_weights[index] += cur_weight
box_counts[index] += 1
index += 1
sorted_indices = np.argsort([t[1] for t in origin_weights],
kind='stable')[::-1]
origin_weights = [origin_weights[idx] for idx in sorted_indices]
# Step 4: Distribute items into boxes based on weight
for item_id, weight in origin_weights:
# Find the box with the least items but not full
min_box_index = -1
for i in range(card_num):
if item_id in boxes[i]:
continue
# Only choose boxes that still have space (box_counts[i] < items_per_box)
if box_counts[i] < items_per_box or (box_counts[i]
== items_per_box
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
min_box_index = i
# Place the item (id) into the selected box
boxes[min_box_index].append(item_id)
boxes_weights[min_box_index].append(weight)
box_weights[min_box_index] += weight
box_counts[min_box_index] += 1
# If there's an imbalance in the remaining items, reduce the "remaining_items" counter
if box_counts[min_box_index] == (items_per_box +
1) and remaining_items > 0:
remaining_items -= 1
# Step 5: Output each box's contents and total weight
result = []
for i in range(card_num):
result.append({
"box_index": i + 1,
"items": boxes[i], # List of item IDs in the box
"weight": boxes_weights[i],
"total_weight": box_weights[i], # Total weight in this box
"item_count": box_counts[i] # Number of items in the box
})
return result, boxes
# Split hot (high-load) experts into redundant experts
@staticmethod
def compute_balanced_pack_redundancy(origin_weights, card_num,
num_redundancy_expert):
route_expert_num = len(origin_weights)
route_expert_redundancy: list[list[int]] = [
[] for _ in range(route_expert_num)
]
for i in range(num_redundancy_expert):
sorted_indices = np.argsort([t[1] for t in origin_weights],
kind='stable')[::-1]
weights = [origin_weights[idx] for idx in sorted_indices]
tmp_raw_weight = weights[0][1] * (
len(route_expert_redundancy[weights[0][0]]) + 1)
route_expert_redundancy[weights[0][0]].append(route_expert_num + i)
avg_weight = tmp_raw_weight / (
len(route_expert_redundancy[weights[0][0]]) + 1)
weights[0] = (weights[0][0], avg_weight)
origin_weights = weights
expert_num = route_expert_num + num_redundancy_expert
if card_num == 0:
raise RuntimeError("card_num can not be 0.")
items_per_box = expert_num // card_num
remaining_items = expert_num % card_num
boxes: list[list[int]] = [[] for _ in range(card_num)]
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
box_weights = [0] * card_num
box_counts = [0] * card_num
all_weights = np.zeros((expert_num, ), dtype='object')
all_weights[:route_expert_num] = origin_weights
index = route_expert_num
for i in range(route_expert_num):
redundancy_num = len(route_expert_redundancy[i])
for _ in range(redundancy_num):
for item, weight in origin_weights:
if item == i:
all_weights[index] = (item, weight)
index += 1
sorted_indices = np.argsort([t[1] for t in all_weights],
kind='stable')[::-1]
all_weights = [all_weights[idx] for idx in sorted_indices]
for item_id, weight in all_weights:
min_box_index = -1
for i in range(card_num):
if box_counts[i] < items_per_box or (box_counts[i]
== items_per_box
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
if item_id not in boxes[i]:
min_box_index = i
boxes[min_box_index].append(item_id)
boxes_weights[min_box_index].append(weight)
box_weights[min_box_index] += weight
box_counts[min_box_index] += 1
if box_counts[min_box_index] == (items_per_box +
1) and remaining_items > 0:
remaining_items -= 1
result = []
for i in range(card_num):
result.append({
"box_index": i + 1,
"items": boxes[i],
"weight": boxes_weights[i],
"total_weight": box_weights[i],
"item_count": box_counts[i]
})
return result, boxes
# Scheme without redundant experts
@staticmethod
def compute_balanced_pack(origin_weights, card_num):
sorted_indices = np.argsort([t[1] for t in origin_weights])[::-1]
weights = origin_weights[sorted_indices]
expert_num = len(weights)
if card_num == 0:
raise RuntimeError("card_num can not be 0.")
items_per_box = expert_num // card_num
remaining_items = expert_num % card_num
boxes: list[list[int]] = [[] for _ in range(card_num)]
boxes_weights: list[list[float]] = [[] for _ in range(card_num)]
box_weights = [0] * card_num
box_counts = [0] * card_num
for item_id, weight in weights:
min_box_index = -1
for i in range(card_num):
if box_counts[i] < items_per_box or (box_counts[i]
== items_per_box
and remaining_items > 0):
if min_box_index == -1 or box_weights[i] < box_weights[
min_box_index]:
min_box_index = i
boxes[min_box_index].append(item_id)
boxes_weights[min_box_index].append(weight)
box_weights[min_box_index] += weight
box_counts[min_box_index] += 1
if box_counts[min_box_index] == (items_per_box +
1) and remaining_items > 0:
remaining_items -= 1
result = []
for i in range(card_num):
result.append({
"box_index": i + 1,
"items": boxes[i],
"weight": boxes_weights[i],
"total_weight": box_weights[i],
"item_count": box_counts[i]
})
return result, boxes
@staticmethod
def get_redundant_num(npu_num, counts):
redundant_num_each_npu: int = np.sum(counts - 1)
return redundant_num_each_npu
@staticmethod
def calculate_max_heat_per_layer(workload_table, layer_num):
max_heat_per_layer: list[float] = []
for layer_idx in range(layer_num):
npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
max_heat_per_layer.append(np.max(npu_heats_now))
return max_heat_per_layer
@staticmethod
def constraint_expert_local_exchange(current_expert_table,
global_deployment):
for layer_id in range(len(global_deployment)):
for card_id in range(len(global_deployment[layer_id])):
current_list = [
int(x) for x in current_expert_table[layer_id][card_id]
]
new_list = [
int(x) for x in global_deployment[layer_id][card_id]
]
num = len(new_list)
new_index = [-1] * num
new_result = [-1] * num
remaining_elements = []
for i in range(num):
flag = True
for j in range(num):
if new_list[i] == current_list[j] and new_index[
j] == -1:
new_index[j] = 0
new_result[j] = current_list[j]
flag = False
break
if flag:
remaining_elements.append(new_list[i])
index = 0
for k in range(num):
if new_result[k] == -1:
new_result[k] = remaining_elements[index]
index += 1
global_deployment[layer_id][card_id] = new_result
return global_deployment
def rebalance_experts(self, current_expert_table, expert_workload):
    """Compute a new expert placement and report whether it is worth applying.

    Args:
        current_expert_table: current placement, convertible to a 3-D array
            ``[layer][npu][slot]`` of expert ids.
        expert_workload: workloads with the same layout.

    Returns:
        ``(change, per_layer_priority, deployment)``: ``change`` is 1 when the
        summed per-layer peak load after rebalancing is below 95% of the value
        before, else 0; ``per_layer_priority`` is the argsort of the per-layer
        after/before peak ratio; ``deployment`` is the new placement as nested
        lists.

    Raises:
        ValueError: if the expert counts disagree, there is no NPU, or there
            are more redundant experts than NPUs.
    """
    info = DynamicTable()
    info.workload_table = np.array(expert_workload)
    info.placement_table = np.array(current_expert_table)
    assert info.workload_table is not None
    layer_num, num_npus, experts_per_npu = info.workload_table.shape
    assert info.placement_table is not None

    # The first layer's placement defines the expert ids and how many
    # redundant copies exist.
    first_layer = cast(np.ndarray, info.placement_table[0])
    expert_ids, counts = np.unique(first_layer, return_counts=True)
    num_redundancy_expert = self.get_redundant_num(num_npus, counts)
    num_original_expert = len(expert_ids)

    layer_workloads = self.add_redundant(info.placement_table,
                                         info.workload_table,
                                         num_original_expert)
    max_heat_per_layer_before = self.calculate_max_heat_per_layer(
        info.workload_table, layer_num)
    npu_heat_all_origin = sum(max_heat_per_layer_before)

    # Perform load balancing and deploy redundant experts
    layer_num = layer_workloads.shape[0]
    expert_num = layer_workloads.shape[1]

    # Validate expert count, NPU count and redundant-expert count.
    if num_original_expert != expert_num:
        raise ValueError(
            f"the number of original experts {num_original_expert} must be equal to expert_num {expert_num}"
        )
    if num_npus <= 0:
        raise ValueError("the number of NPUs must be greater than 0")
    if num_npus < num_redundancy_expert:
        raise ValueError(
            f"the number of NPUs {num_npus} must be greater than or equal to the number of redundant experts {num_redundancy_expert}"
        )

    global_deployment: list[list[list[int]]] = [[[]
                                                 for _ in range(num_npus)]
                                                for _ in range(layer_num)]
    max_heat_per_layer_after = np.zeros([layer_num])
    for layer in range(layer_num):
        # Pair every expert id with its aggregated workload for this layer.
        weights = np.zeros((expert_num, ), dtype='object')
        for expert_id, workload_weight in enumerate(layer_workloads[layer]):
            weights[expert_id] = (expert_id, workload_weight)

        # Globally balanced placement for this layer.
        result, layer_deployment = self.original_compute_balanced_pack_redundancy(
            weights, num_npus, num_redundancy_expert)
        global_deployment[layer] = layer_deployment
        max_heat_per_layer_after[layer] = max(box['total_weight']
                                              for box in result)

    new_global_deployment = self.constraint_expert_local_exchange(
        current_expert_table, global_deployment)

    # Layers with the smallest after/before ratio come first.
    layer_changed_ratio = [
        max_heat_per_layer_after[layer_idx] /
        max_heat_per_layer_before[layer_idx]
        for layer_idx in range(layer_num)
    ]
    per_layer_priority = np.argsort(layer_changed_ratio)

    npu_heat_all_after = sum(max_heat_per_layer_after)
    change = 1 if npu_heat_all_after < 0.95 * npu_heat_all_origin else 0

    return change, per_layer_priority, np.array(
        new_global_deployment).tolist()

View File

@@ -0,0 +1,771 @@
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
from abc import abstractmethod
from collections import defaultdict
import numpy as np
class DynamicConfig:
    """Class-level default settings handed to ``EplbPolicy`` instances."""
    placement_policy = None
    max_transferred_expert_per_layer = 100  # Maximum number of experts that can be migrated per layer on a single host
    ep_worldsize = 64  # Total number of dies across the entire cluster where experts are distributed
    num_die_per_host = 8  # Number of dies on each host machine
class EplbPolicy:
    """Base class for expert-placement policies; stores the configuration."""

    def __init__(self, config: DynamicConfig):
        # Keep the configuration on the instance.
        self.config = config

    # NOTE(review): this class does not derive from abc.ABC, so
    # @abstractmethod is not enforced at instantiation time; calling the base
    # method simply returns None.
    @abstractmethod
    def rebalance_experts(self, current_expert_table, expert_workload):
        """
        Pass in the weights and return expert replication and placement under relevant constraints.

        INPUT:
        current_expert_table: [layerId, rankId, expert_num_i]
        expert_workload = expert_table[layer0][rankId][expert_num_i]

        RETURNED: (res, expert_table)
        res:
        1 -- table_changed
        0 -- not_changed

        expert_table: [layerId, rankId, expert_num_i]
        expert_num_i --- [0, MaxExpertPerRank]
        expertID = expert_table[layer0][rankId][expert_num_i]
        array_values:
        [0, 1, 2, 3, 248]
        [4, 5, 6, 7, 254]
        [8, 9, 10, 11, 71]
        ...
        [252, 253, 254, 255, 0]

        NOTE(review): concrete policies may return extra values (for example a
        per-layer priority); confirm against the subclass in use.
        """
        pass
class DynamicTable:
    """Holder for the workload and placement tables consumed by the policy.

    Both attributes are class-level ``None`` placeholders that are assigned on
    an instance before use.
    """
    # workload_table:
    # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: workload (heat) at the corresponding position
    # Size: number of layers * number of GPUs * number of experts per GPU per layer
    # The element at (i, j, k) represents the workload (heat) of the k-th expert on the j-th GPU in the i-th layer
    # For experts that are not available or collected, the value is set to -1
    workload_table = None
    # placement_table:
    # 3D matrix: [layer, gpus, experts_per_gpu_per_layer] -> value: physical expert ID at the corresponding position
    # Size: number of layers * number of GPUs * number of experts per GPU per layer
    # The element at (i, j, k) represents the physical expert ID of the k-th expert on the j-th GPU in the i-th layer
    # For experts that are not available or collected, the value is set to -1
    placement_table = None
class DynamicEplbV2(EplbPolicy):
def __init__(self, config: DynamicConfig):
    """Initialise the policy by forwarding ``config`` to the ``EplbPolicy`` constructor."""
    super().__init__(config)
@staticmethod
def safe_divide(a, b):
if b == 0:
print("Division by zero is not allowed")
return 0
return a / b
@staticmethod
def safe_exact_divide(a, b):
if b == 0:
print("Division by zero is not allowed")
return 0
return a // b
@staticmethod
def safe_mod(a, b):
if b == 0:
print("Division by zero is not allowed")
return 0
return a % b
@staticmethod
def add_redundant(current_expert_table, expert_workload,
num_original_expert):
layer_num, npu_num, experts_per_npu = expert_workload.shape
workload_new = np.zeros((layer_num, num_original_expert))
for layer_idx in range(layer_num):
workload_dict: dict[int, int] = defaultdict(int)
placement_layer = current_expert_table[layer_idx].copy()
workload_layer = expert_workload[layer_idx].copy()
for npu_idx in range(npu_num):
for expert_idx in range(experts_per_npu):
workload_dict[placement_layer[npu_idx][
expert_idx]] += workload_layer[npu_idx][expert_idx]
for expert_idx in range(num_original_expert):
workload_new[layer_idx][expert_idx] = workload_dict[expert_idx]
return workload_new
@staticmethod
def get_redundant_num(npu_num, counts):
redundant_num_each_npu: int = int(np.sum(counts - 1))
return redundant_num_each_npu
@staticmethod
def calculate_max_heat_per_layer(workload_table, layer_num):
max_heat_per_layer: list[float] = []
for layer_idx in range(layer_num):
npu_heats_now = np.sum(workload_table[layer_idx], axis=1)
max_heat_per_layer.append(np.max(npu_heats_now))
return max_heat_per_layer
def calculate_initial_imbalance(self, global_deployment,
                                new_layer_workloads):
    """Return, per layer, the ratio of the busiest device load to the mean load.

    Each expert's workload is shared equally among all of its placed copies
    before device loads are summed.

    Args:
        global_deployment: array indexed ``[layer][device][slot]`` of expert
            ids; ``shape[1]`` is taken as the device count.
        new_layer_workloads: array indexed ``[layer][expert_id]``.

    Returns:
        A list with one imbalance ratio per layer (0 when the mean is 0,
        because ``safe_divide`` returns 0 on a zero divisor).
    """
    device_num = global_deployment.shape[1]

    # Count how many copies of each expert are deployed in every layer.
    replica_count = np.zeros_like(new_layer_workloads)
    for layer_id, layer in enumerate(global_deployment):
        for device in layer:
            for expert_id in device:
                replica_count[layer_id][expert_id] += 1

    layer_imbalance = []
    for layer_id, layer in enumerate(global_deployment):
        peak_load = 0
        layer_total = 0
        for box in layer:
            box_load = 0
            for expert_id in box:
                share = self.safe_divide(
                    new_layer_workloads[layer_id][expert_id],
                    replica_count[layer_id][expert_id])
                box_load += share
                layer_total += share
            peak_load = max(peak_load, box_load)
        mean_load = self.safe_divide(layer_total, device_num)
        layer_imbalance.append(self.safe_divide(peak_load, mean_load))
    return layer_imbalance
def compute_redundant_assignments(self, base_experts,
                                  num_redundant_experts, num_experts):
    """Decide which experts get redundant copies and update their weights.

    The heaviest expert is picked ``num_redundant_experts`` times; each pick
    restores its un-split weight and averages it over one more copy.

    Args:
        base_experts: sequence of ``(expert_id, weight)`` pairs supporting
            ``.copy()`` and item assignment.
        num_redundant_experts: number of copies to hand out.
        num_experts: length of the returned assignment list; new copy ids
            start at this value.

    Returns:
        ``(redundant_assignments, sorted_weights)``: the copy ids given to
        each expert, and the updated pairs sorted by weight, descending.
    """
    redundant_assignments: list[list[int]] = [[]
                                              for _ in range(num_experts)]
    current_weights = base_experts.copy()

    def descending_order():
        # Reversed stable ascending argsort, exactly as used for every pick.
        return np.argsort([w for _, w in current_weights],
                          kind='stable')[::-1]

    for extra in range(num_redundant_experts):
        hottest = descending_order()[0]
        expert_id, original_weight = current_weights[hottest]
        copies = redundant_assignments[expert_id]
        new_avg_weight = self.safe_divide(
            original_weight * (len(copies) + 1), (len(copies) + 2))
        copies.append(num_experts + extra)
        current_weights[hottest] = (expert_id, new_avg_weight)

    sorted_weights = [current_weights[pos] for pos in descending_order()]
    return redundant_assignments, sorted_weights
def repeat_compute_redundant_assignments(self, layer_workloads, rendun_pos,
                                         num_experts, num_exist_expert,
                                         device_assignments, device_counts,
                                         expert_from_device,
                                         com_between_devices):
    """Fill the still-empty redundant slots with the currently hottest experts.

    Each round places one extra copy of the heaviest expert that is not yet
    present on some device with a free slot, then lowers that expert's
    per-copy weight.

    Args:
        layer_workloads: per-copy weight of each expert, indexed by expert id
            (-1 marks an expert with no copy).
        rendun_pos: per-device list of free slot indices; consumed with
            ``pop()``.
        num_experts: number of experts (length of the weight table).
        num_exist_expert: per-expert count of placed copies; updated in place.
        device_assignments: per-device expert ids; updated in place.
        device_counts: per-device number of placed experts; updated in place.
        expert_from_device: indexed by expert id; the value is used as the
            key recorded in ``com_between_devices``.
        com_between_devices: per-device dict; updated in place.

    Returns:
        ``(weights, device_assignments, device_counts, com_between_devices)``
        where ``weights[e]`` is the final per-copy weight of expert ``e``.
    """
    current_weights = np.zeros((num_experts, ), dtype='object')
    for expert_id, workload_weight in enumerate(layer_workloads):
        current_weights[expert_id] = (expert_id, workload_weight)

    devices_with_slots = [
        device_id for device_id, device_rendun_pos in enumerate(rendun_pos)
        if len(device_rendun_pos) != 0
    ]

    while devices_with_slots:
        sorted_indices = np.argsort([w for _, w in current_weights],
                                    kind='stable')[::-1]
        placed = False
        for weight_index in sorted_indices:
            expert_id, original_weight = current_weights[weight_index]
            if original_weight == -1:
                # Only experts without any copy are left: nothing more can
                # be replicated.
                print("Error:Redundant expert failure re-occurred")
                break
            for cur_device_id in devices_with_slots:
                if expert_id in device_assignments[cur_device_id]:
                    continue
                pos = rendun_pos[cur_device_id].pop()
                if len(rendun_pos[cur_device_id]) == 0:
                    devices_with_slots = [
                        device_id for device_id in devices_with_slots
                        if device_id != cur_device_id
                    ]
                device_assignments[cur_device_id][pos] = expert_id
                device_counts[cur_device_id] += 1
                communication_box_index = expert_from_device[expert_id]
                com_between_devices[cur_device_id][
                    communication_box_index] = expert_id
                new_weight = self.safe_divide(
                    (original_weight * num_exist_expert[expert_id]),
                    (num_exist_expert[expert_id] + 1))
                # Write the reduced weight back to the table that drives the
                # next round. It used to be stored only in a temporary
                # sorted list, so it was lost on the next iteration.
                current_weights[weight_index] = (expert_id, new_weight)
                num_exist_expert[expert_id] += 1
                placed = True
                break
            if placed:
                break
        if not placed:
            # No expert could be placed on any device with a free slot; the
            # state can no longer change, so stop instead of looping forever.
            break

    # current_weights is indexed by expert id, so this is already in id order.
    sorted_weights = [weight for _, weight in current_weights]
    return sorted_weights, device_assignments, device_counts, com_between_devices
@staticmethod
def prepare_expert_list(base_experts, redundant_assignments,
num_redundant_experts):
redundant_expert_list = np.empty(num_redundant_experts, dtype=object)
index = 0
num_experts = len(redundant_assignments)
for expert_id in range(num_experts):
for _ in redundant_assignments[expert_id]:
redundant_expert_list[index] = (expert_id,
next(w
for eid, w in base_experts
if eid == expert_id))
index += 1
sorted_indices = np.argsort([w for _, w in redundant_expert_list],
kind='stable')[::-1]
return [redundant_expert_list[i] for i in sorted_indices]
@staticmethod
def non_redundant_expert_information(origin_deployment, updated_weights,
rendun_pos):
device_num = len(origin_deployment)
num_experts_per_device = origin_deployment.shape[1]
device_assignments = [[-1 for _ in range(num_experts_per_device)]
for _ in range(device_num)]
device_weights = [[0 for _ in range(num_experts_per_device)]
for _ in range(device_num)]
device_loads = [0] * device_num
device_counts = [0] * device_num
for device_id, device in enumerate(origin_deployment):
for index, expert_id in enumerate(device):
if index in rendun_pos[device_id]:
continue
device_assignments[device_id][index] = expert_id
cur_weight = next(
weight for expert_id_of_weight, weight in updated_weights
if expert_id_of_weight == expert_id)
device_weights[device_id][index] = cur_weight
device_loads[device_id] += cur_weight
device_counts[device_id] += 1
return device_assignments, device_weights, device_loads, device_counts
def recomputing_initial_weight(self, layer_workloads, device_assignments):
    """Split every expert's workload evenly across its placed copies.

    Args:
        layer_workloads: total workload per expert, indexed by expert id.
        device_assignments: per-device expert ids; -1 marks an empty slot.

    Returns:
        ``(cur_layer_workload, num_all_experts)``: the per-copy workload of
        each expert (-1 when it has no copy) and the copy count per expert.
    """
    num_all_experts = [0] * len(layer_workloads)
    for device in device_assignments:
        for expert_id in device:
            if expert_id != -1:
                num_all_experts[expert_id] += 1

    cur_layer_workload = [
        -1 if num_all_experts[expert_id] == 0 else self.safe_divide(
            weight, num_all_experts[expert_id])
        for expert_id, weight in enumerate(layer_workloads)
    ]
    return cur_layer_workload, num_all_experts
def distribute_redun_experts(self, layer_workloads, device_assignments,
                             device_weights, device_loads, device_counts,
                             redundant_expert_list, expert_from_device,
                             num_experts, rendun_pos):
    """Place redundant expert copies into the free slots of the devices.

    Each copy goes to the least-loaded device that still has a free slot and
    does not already host that expert. If free slots remain afterwards, they
    are filled by ``repeat_compute_redundant_assignments`` and the device
    weights and loads are recomputed from the updated per-copy workloads.

    Args:
        layer_workloads: total workload per expert for this layer.
        device_assignments: per-device expert ids; updated in place.
        device_weights: per-device slot weights; updated in place.
        device_loads: per-device total load; updated in place (rebuilt as a
            new list when the second phase runs).
        device_counts: per-device number of placed experts; updated in place.
        redundant_expert_list: iterable of ``(expert_id, weight)`` copies.
        expert_from_device: indexed by expert id; the value is used as the
            key recorded in ``com_between_devices``.
        num_experts: number of experts, forwarded to the second phase.
        rendun_pos: per-device list of free slot indices; consumed with
            ``pop()``.

    Returns:
        ``(device_assignments, device_weights, device_loads, device_counts,
        com_between_devices)``.
    """
    num_devices = len(device_assignments)
    com_between_devices: list[dict[int,
                                   int]] = [{} for _ in range(num_devices)]

    for expert_id, weight in redundant_expert_list:
        candidate = -1
        for dev_id in range(num_devices):
            if len(rendun_pos[dev_id]) == 0:
                continue
            if expert_id in device_assignments[dev_id]:
                continue
            if candidate == -1 or device_loads[dev_id] < device_loads[
                    candidate]:
                candidate = dev_id
        if candidate != -1:
            pos = rendun_pos[candidate].pop()
            device_assignments[candidate][pos] = expert_id
            device_weights[candidate][pos] = weight
            device_loads[candidate] += weight
            device_counts[candidate] += 1
            communication_box_index = expert_from_device[expert_id]
            com_between_devices[candidate][
                communication_box_index] = expert_id

    if any(sublist for sublist in rendun_pos):
        cur_layer_workload, num_exist_expert = self.recomputing_initial_weight(
            layer_workloads, device_assignments)

        # Pass device_counts. The previous call handed device_loads to the
        # device_counts parameter, so the returned counts were really loads
        # with increments applied.
        update_workload, device_assignments, device_counts, com_between_devices = self.repeat_compute_redundant_assignments(
            cur_layer_workload, rendun_pos, num_experts, num_exist_expert,
            device_assignments, device_counts, expert_from_device,
            com_between_devices)

        device_loads = [0] * len(device_counts)
        for device_id, device in enumerate(device_assignments):
            for index, expert_id in enumerate(device):
                if expert_id == -1:
                    # Unfilled slot: -1 must not index the weight list from
                    # the end.
                    continue
                device_weights[device_id][index] = update_workload[
                    expert_id]
                device_loads[device_id] += update_workload[expert_id]

    return device_assignments, device_weights, device_loads, device_counts, com_between_devices
def redundancy_again(self, layer_workloads, origin_weights,
                     origin_deployment, expert_from_device, num_node,
                     is_node_redundant, rendun_pos):
    """Run the redundant-expert pipeline for one group of devices.

    The steps are: choose which experts to replicate, list the copies, record
    the non-redundant slots, then distribute the copies over the free slots.

    Args:
        layer_workloads: total workload per expert for the layer.
        origin_weights: ``(expert_id, weight)`` pairs for this device group.
        origin_deployment: 2-D array ``[device][slot]`` of expert ids.
        expert_from_device: forwarded to ``distribute_redun_experts``.
        num_node: number of nodes; multiplies the expert count when
            ``is_node_redundant`` is true.
        is_node_redundant: whether the call covers a single node's devices.
        rendun_pos: per-device list of free slot indices.

    Returns:
        ``(device_assignments, device_weights, device_loads, device_counts,
        com_between_devices)``.
    """
    if is_node_redundant:
        num_experts = len(origin_weights) * num_node
    else:
        num_experts = len(origin_weights)

    # One redundant expert is needed for every free slot.
    num_redundant_experts = sum(
        len(rank_empty_pos) for rank_empty_pos in rendun_pos)

    redundant_assignments, updated_weights = self.compute_redundant_assignments(
        origin_weights, num_redundant_experts, num_experts)
    redundant_expert_list = self.prepare_expert_list(
        updated_weights, redundant_assignments, num_redundant_experts)

    base_assignments, base_weights, base_loads, base_counts = self.non_redundant_expert_information(
        origin_deployment, updated_weights, rendun_pos)

    device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.distribute_redun_experts(
        layer_workloads, base_assignments, base_weights, base_loads,
        base_counts, redundant_expert_list, expert_from_device, num_experts,
        rendun_pos)

    return device_assignments, device_weights, device_loads, device_counts, com_between_devices
@staticmethod
def generate_allocation_report(device_assignments, device_weights,
device_loads, device_counts):
report = []
max_load = 0.0
for dev_id in range(len(device_assignments)):
current_load = device_loads[dev_id]
max_load = max(max_load, current_load)
report.append({
"device_id": dev_id + 1,
"assigned_experts": device_assignments[dev_id],
"expert_weights": device_weights[dev_id],
"total_load": current_load,
"expert_count": device_counts[dev_id]
})
return report, max_load
@staticmethod
def exchange_expert(cur_exchange_index, next_exchange_index, cur_device_id,
next_device_id, cur_layer_result, com_between_devices):
cur_device_deployment = cur_layer_result[cur_device_id][
'assigned_experts']
next_device_deployment = cur_layer_result[next_device_id][
'assigned_experts']
cur_device_weight = cur_layer_result[cur_device_id]['expert_weights']
next_device_weight = cur_layer_result[next_device_id]['expert_weights']
cur_expert_id = cur_device_deployment[cur_exchange_index]
next_expert_id = next_device_deployment[next_exchange_index]
cur_device_deployment[cur_exchange_index] = next_expert_id
next_device_deployment[next_exchange_index] = cur_expert_id
cur_expert_weight = cur_device_weight[cur_exchange_index]
next_expert_weight = next_device_weight[next_exchange_index]
cur_device_weight[cur_exchange_index] = next_expert_weight
next_device_weight[next_exchange_index] = cur_expert_weight
cur_layer_result[cur_device_id][
'total_load'] += next_expert_weight - cur_expert_weight
cur_layer_result[next_device_id][
'total_load'] += cur_expert_weight - next_expert_weight
com_between_devices[cur_device_id][next_device_id] = next_expert_id
com_between_devices[next_device_id][cur_device_id] = cur_expert_id
def redundant_expert_deployment(self, layer_workloads, original_deployment,
                                expert_from_device, node_num,
                                is_node_redundant, rendun_pos):
    """Deploy redundant experts for one layer, globally or node by node.

    Args:
        layer_workloads: 1-D array of total workload per expert.
        original_deployment: 2-D array ``[device][slot]`` of expert ids.
        expert_from_device: forwarded to ``redundancy_again``.
        node_num: number of nodes the devices are split into.
        is_node_redundant: when true, each node's slice of devices and experts
            is processed separately and the results are concatenated.
        rendun_pos: per-device list of free slot indices.

    Returns:
        ``(report, max_load, com_between_devices)`` where ``report`` and
        ``max_load`` come from ``generate_allocation_report``.
    """
    device_num, per_device_expert_num = original_deployment.shape
    route_expert_num = layer_workloads.shape[0]
    per_node_device_num = self.safe_exact_divide(device_num, node_num)
    # The per-node expert count is derived as (slots per device - 1) per device.
    per_node_route_expert_num = per_node_device_num * (
        per_device_expert_num - 1)

    weights = np.zeros((route_expert_num, ), dtype='object')
    for expert_id, workload_weight in enumerate(layer_workloads):
        weights[expert_id] = (expert_id, workload_weight)

    if not is_node_redundant:
        device_assignments, device_weights, device_loads, device_counts, com_between_devices = self.redundancy_again(
            layer_workloads, weights, original_deployment,
            expert_from_device, node_num, is_node_redundant, rendun_pos)
    else:
        device_assignments = []
        device_weights = []
        device_loads = []
        device_counts = []
        com_between_devices = []
        for node_id in range(node_num):
            expert_span = slice(node_id * per_node_route_expert_num,
                                (node_id + 1) * per_node_route_expert_num)
            device_span = slice(node_id * per_node_device_num,
                                (node_id + 1) * per_node_device_num)
            node_assignments, node_weights, node_loads, node_counts, node_com = self.redundancy_again(
                layer_workloads, weights[expert_span],
                original_deployment[device_span], expert_from_device,
                node_num, is_node_redundant, rendun_pos[device_span])
            device_assignments += node_assignments
            device_weights += node_weights
            device_loads += node_loads
            device_counts += node_counts
            com_between_devices += node_com

    report, max_load = self.generate_allocation_report(
        device_assignments, device_weights, device_loads, device_counts)
    return report, max_load, com_between_devices
@staticmethod
def two_device_exchange_experts(cur_device_result, exchange_device_result,
cur_exchanged_expert_id,
next_exchanged_expert_id, ave_workload,
increment, num_redundancy_expert):
cur_device_weight = cur_device_result['expert_weights']
next_device_weight = exchange_device_result['expert_weights']
cur_device_expert_id = cur_device_result['assigned_experts']
next_device_expert_id = exchange_device_result['assigned_experts']
cur_device_total_weight = cur_device_result['total_load']
next_device_total_weight = exchange_device_result['total_load']
max_weight = max(cur_device_total_weight, next_device_total_weight)
cur_exchange_index = -1
next_exchange_index = -1
for index, weight in enumerate(cur_device_weight):
for next_index, next_weight in enumerate(next_device_weight):
change_flag = True
if (cur_device_expert_id[index] in next_device_expert_id
or next_device_expert_id[next_index]
in cur_device_expert_id):
change_flag = False
if (cur_device_expert_id[index] not in cur_exchanged_expert_id
) and (next_device_expert_id[next_index]
not in next_exchanged_expert_id) and change_flag:
cur_total_weight_after_exchange = cur_device_total_weight - weight + next_weight
next_total_weight_after_exchange = next_device_total_weight - next_weight + weight
exchange_max_weight = max(
cur_total_weight_after_exchange,
next_total_weight_after_exchange)
if exchange_max_weight < max_weight and (
max_weight -
exchange_max_weight) >= (ave_workload * increment):
max_weight = exchange_max_weight
cur_exchange_index = index
next_exchange_index = next_index
return cur_exchange_index, next_exchange_index
def expert_exchange_between_devices(self,
                                    ave_workload,
                                    increment,
                                    cur_layer_result,
                                    com_between_devices,
                                    num_redundancy_expert,
                                    node_idx=0,
                                    per_node_device_num=0,
                                    is_node_redundant=False):
    """Repeatedly swap experts between the heaviest device and a lighter one.

    In each round the devices are sorted by load; the heaviest device is
    paired with lighter devices in ascending load order until
    ``two_device_exchange_experts`` finds an acceptable swap, which is then
    applied with ``exchange_expert``. The loop stops after 100 rounds or as
    soon as a round produces no swap. Nothing is returned; the swaps modify
    ``cur_layer_result`` and ``com_between_devices`` in place.

    Args:
        ave_workload: average device workload, forwarded to the swap search.
        increment: minimum relative improvement, forwarded to the swap search.
        cur_layer_result: per-device dicts (``total_load``, ``device_id``, ...).
        com_between_devices: per-device dict of recorded exchanges.
        num_redundancy_expert: forwarded to the swap search.
        node_idx: node to process when ``is_node_redundant`` is true.
        per_node_device_num: devices per node when ``is_node_redundant`` is true.
        is_node_redundant: restrict the exchange to one node's slice of devices.
    """
    if is_node_redundant:
        cur_devices_result = cur_layer_result[node_idx *
                                              per_node_device_num:
                                              (node_idx + 1) *
                                              per_node_device_num]
    else:
        cur_devices_result = cur_layer_result

    # (load, zero-based device index) for every device under consideration.
    devices_total_weight = []
    for device in cur_devices_result:
        devices_total_weight.append(
            (device['total_load'], device['device_id'] - 1))

    # Upper bound on the number of exchange rounds.
    exchange_frequency = 100
    while exchange_frequency > 0:
        exchange_frequency -= 1
        # Ascending by load: the last entry is the heaviest device.
        devices_total_weight.sort(key=lambda x: x[0])
        max_weight_device_id = devices_total_weight[-1][1]
        exchange = False
        for index in range(0, len(devices_total_weight) - 1):
            min_weight_device_id = devices_total_weight[index][1]
            # Skip partners already recorded as a key for the heaviest device.
            if min_weight_device_id not in com_between_devices[
                    max_weight_device_id]:
                cur_exchanged_expert_id = list(
                    com_between_devices[max_weight_device_id].values())
                next_exchanged_expert_id = list(
                    com_between_devices[min_weight_device_id].values())

                cur_exchange_index, next_exchange_index = self.two_device_exchange_experts(
                    cur_layer_result[max_weight_device_id],
                    cur_layer_result[min_weight_device_id],
                    cur_exchanged_expert_id, next_exchanged_expert_id,
                    ave_workload, increment, num_redundancy_expert)

                # -1 means no acceptable swap exists for this pair.
                if cur_exchange_index != -1:
                    self.exchange_expert(cur_exchange_index,
                                         next_exchange_index,
                                         max_weight_device_id,
                                         min_weight_device_id,
                                         cur_layer_result,
                                         com_between_devices)

                    # Refresh the cached loads of the two devices involved.
                    devices_total_weight[-1] = (
                        cur_layer_result[max_weight_device_id]
                        ['total_load'], max_weight_device_id)
                    devices_total_weight[index] = (
                        cur_layer_result[min_weight_device_id]
                        ['total_load'], min_weight_device_id)
                    exchange = True
                    break

        # No swap was possible in this round: further rounds cannot help.
        if not exchange:
            break
def exchange_experts(self, layer_result, layer_com_between_devices,
                     num_nodes, device_num, is_node_redundant,
                     ave_workload, increment, num_redundancy_expert,
                     org_deployment):
    """Run the expert exchange for one layer and collect the deployment.

    Args:
        layer_result: per-device dicts (``assigned_experts``, ``total_load``,
            ...); modified in place by the exchange.
        layer_com_between_devices: per-device dict of recorded exchanges.
        num_nodes: number of nodes, used when ``is_node_redundant`` is true.
        device_num: total number of devices.
        is_node_redundant: run the exchange separately inside each node.
        ave_workload: average device workload.
        increment: minimum relative improvement for a swap.
        num_redundancy_expert: forwarded to the exchange routine.
        org_deployment: unused.

    Returns:
        ``(global_deployment, max_workload)``: the per-device expert ids as a
        NumPy array and the largest device load (at least 0).
    """
    if is_node_redundant:
        per_node_device_num = self.safe_exact_divide(device_num, num_nodes)
        for node_idx in range(num_nodes):
            self.expert_exchange_between_devices(
                ave_workload, increment, layer_result,
                layer_com_between_devices, num_redundancy_expert, node_idx,
                per_node_device_num, is_node_redundant)
    else:
        self.expert_exchange_between_devices(ave_workload, increment,
                                             layer_result,
                                             layer_com_between_devices,
                                             num_redundancy_expert)

    deployment_rows = [box['assigned_experts'] for box in layer_result]
    max_workload = 0
    for box in layer_result:
        if max_workload < box['total_load']:
            max_workload = box['total_load']

    global_deployment = np.array(deployment_rows)
    return global_deployment, max_workload
def count_elements(self, lst):
    """Count the non-list items in ``lst``, descending into nested lists."""
    return sum(
        self.count_elements(item) if isinstance(item, list) else 1
        for item in lst)
@staticmethod
def constraint_expert_local_exchange(current_expert_table,
global_deployment):
for layer_id in range(len(global_deployment)):
for card_id in range(len(global_deployment[layer_id])):
current_list = [
int(x) for x in current_expert_table[layer_id][card_id]
]
new_list = [
int(x) for x in global_deployment[layer_id][card_id]
]
num = len(new_list)
new_index = [-1] * num
new_result = [-1] * num
remaining_elements = []
for i in range(num):
flag = True
for j in range(num):
if new_list[i] == current_list[j] and new_index[
j] == -1:
new_index[j] = 0
new_result[j] = current_list[j]
flag = False
break
if flag:
remaining_elements.append(new_list[i])
index = 0
for k in range(num):
if new_result[k] == -1:
new_result[k] = remaining_elements[index]
index += 1
global_deployment[layer_id][card_id] = new_result
return global_deployment
def rebalance_experts(self,
current_expert_table,
expert_workload,
is_node_redundant=False,
increment=0.01):
# Compute a rebalanced expert placement for every layer.
# Both tables must expose `.numpy()`; the workload array is unpacked as
# (layer_num, num_npus, experts_per_npu).
# Returns (change, per_layer_priority, new_deployment):
#   change              1 when the summed per-layer peak load falls below 95%
#                       of its original value, otherwise 0;
#   per_layer_priority  argsort of the after/before peak-load ratios;
#   new_deployment      the placement as nested Python lists.
# Raises ValueError when the expert / NPU counts fail the checks below.
info = DynamicTable()
info.workload_table = expert_workload.numpy()
info.placement_table = current_expert_table.numpy()
assert info.workload_table is not None
layer_num, num_npus, experts_per_npu = info.workload_table.shape
# Expert ids and their replica counts are taken from layer 0 only.
expert_ids, counts = np.unique(info.placement_table[0],
return_counts=True)
num_redundancy_expert = self.get_redundant_num(num_npus, counts)
num_original_expert = len(expert_ids)
layer_workloads = self.add_redundant(info.placement_table,
info.workload_table,
num_original_expert)
max_heat_per_layer_before = self.calculate_max_heat_per_layer(
info.workload_table, layer_num)
npu_heat_all_origin = sum(max_heat_per_layer_before)
# The node count is derived with a hard-coded 8 devices per node.
num_node = self.safe_exact_divide(num_npus, 8)
layer_num = layer_workloads.shape[0]
expert_num = layer_workloads.shape[1]
expert_from_device = np.zeros((layer_num, num_original_expert))
if num_original_expert != expert_num:
raise ValueError(
f"The number of original experts ({num_original_expert}) must match expert_num ({expert_num})"
)
if num_npus <= 0:
raise ValueError("The number of NPUs must be greater than 0")
if num_npus < num_redundancy_expert:
raise ValueError(
f"The number of NPUs ({num_npus}) must be greater than or equal to the number of redundant experts ({num_redundancy_expert})"
)
global_deployment: list[list[list[int]]] = [[[]
for _ in range(num_npus)]
for _ in range(layer_num)]
layer_initial_imbalance = self.calculate_initial_imbalance(
info.placement_table, layer_workloads)
max_heat_per_layer_after = np.zeros([layer_num])
sum_num = 0
for layer in range(layer_num):
# print(f"Load imbalance ratio of layer {layer} under the new workload", layer_initial_imbalance[layer])
# Layers whose imbalance is below 1.01 keep their current placement.
# NOTE(review): for these layers max_heat_per_layer_after stays 0, which
# lowers npu_heat_all_after and the layer's ratio used for the priority
# order below - confirm this is intended.
if layer_initial_imbalance[layer] < 1.01:
global_deployment[layer] = info.placement_table[layer]
continue
ave_workload = self.safe_divide(np.sum(layer_workloads[layer]),
num_npus)
# Remember the first device seen for each expert; every later occurrence of
# the same expert is recorded as a redundant slot on its device.
rendun_pos: list[list[int]] = [[] for _ in range(num_npus)]
existing_experts = set()
for device_id, device in enumerate(info.placement_table[layer]):
for index, expert_id in enumerate(device):
if expert_id not in existing_experts:
existing_experts.add(expert_id)
expert_from_device[layer][expert_id] = device_id
else:
rendun_pos[device_id].append(index)
result, max_workload, com_between_devices = self.redundant_expert_deployment(
layer_workloads[layer], info.placement_table[layer],
expert_from_device[layer], num_node, is_node_redundant,
rendun_pos)
# print(layer, f"Imbalance Ratio after Redundancy Adjustment:", self.safe_divide(max_workload, ave_workload))
global_deployment[layer], new_max_workload = self.exchange_experts(
result, com_between_devices, num_node, num_npus,
is_node_redundant, ave_workload, increment,
num_redundancy_expert, info.placement_table[layer])
# print(layer, f"Imbalance Ratio after Swap Adjustment:", self.safe_divide(new_max_workload, ave_workload))
# Each device's dict is shallow-copied; sum_num is accumulated here but is
# not read again in this function.
for device_id in range(num_npus):
com_between_devices[device_id] = {
key: value
for key, value in com_between_devices[device_id].items()
}
sum_num += self.count_elements(com_between_devices[device_id])
max_heat_per_layer_after[layer] = max(
result, key=lambda x: x['total_load'])['total_load']
layer_changed_ratio = []
for layer_idx in range(layer_num):
layer_changed_ratio.append(
self.safe_divide(max_heat_per_layer_after[layer_idx],
max_heat_per_layer_before[layer_idx]))
# Smallest after/before ratio comes first in the priority order.
per_layer_priority = np.argsort(layer_changed_ratio)
npu_heat_all_after = sum(max_heat_per_layer_after)
change = 0
if npu_heat_all_after < 0.95 * npu_heat_all_origin:
change = 1
# Reorder each card's experts so that experts it already holds keep their slot.
new_global_deployment = self.constraint_expert_local_exchange(
current_expert_table, global_deployment)
return change, per_layer_priority, np.array(
new_global_deployment).tolist()

View File

@@ -0,0 +1,33 @@
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this factory.
from .policy_abstract import DynamicConfig, EplbPolicy
from .policy_dynamic_ep import DynamicEplb
from .policy_dynamic_ep_v2 import DynamicEplbV2
from .policy_flashlb import FlashLB
from .policy_random import RandomLoadBalance
class PolicyFactory:
    """Factory that turns an integer policy id into an EPLB policy object."""

    @staticmethod
    def generate_policy(policy_type: int, config: DynamicConfig) -> EplbPolicy:
        """Instantiate the policy registered under ``policy_type``.

        Unknown ids fall back to ``RandomLoadBalance``. For id 3 the freshly
        built instance's ``warm_up()`` is invoked before it is returned.
        """
        registry = {
            # RandomLoadBalance: shuffle last physical expert on NPU 1 and 3
            0: RandomLoadBalance,
            # Dynamic EPLB policy: overall expert replacement based on
            # current moe load
            1: DynamicEplb,
            # Dynamic EPLB policy V2: expert replacement with constrained
            # number of expert shuffle.
            # Constraint applying Dynamic EPLB policy V2:
            #   If there exists redundant expert:
            #   only one redundant expert can be placed in one NPU and its
            #   physical expert index must be 0
            # Applying greedy d2d expert weight update composing
            2: DynamicEplbV2,
            # FlashLB EPLB policy: expert replacement based on Joint
            # Optimization, Multi-Shot Enhancement and Incremental Adjustment
            3: FlashLB,
        }
        selected_cls = registry.get(policy_type, RandomLoadBalance)
        instance = selected_cls(config)
        if policy_type != 3:
            return instance
        # NOTE(review): this relies on the id-3 policy object exposing a
        # `warm_up` method - confirm it is defined on the class or its base.
        instance.warm_up()
        return instance

View File

@@ -0,0 +1,651 @@
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
import logging
from collections import deque
from typing import Dict
import numpy as np
import torch
from numba import njit # type: ignore
from .policy_abstract import DynamicConfig, EplbPolicy
# Raise the threshold of the "numba" logger to WARNING so that records below
# that level emitted through it are dropped.
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
@njit
def compute_piece_counts(X, P, stage_weights):
# Greedily decide how many pieces each of the N columns of X is split into so
# that the piece counts add up to P.
# X is unpacked as (n_stage, N); stage_weights is indexed once per stage.
# Every column starts with one piece and the remaining P - N pieces are handed
# out one per round. Returns the int32 array of piece counts (length N).
n_stage, N = X.shape
S = P - N
pieces = np.ones(N, dtype=np.int32)
unit = X / pieces # unit[i, j] = X[i, j] / pieces[j]
for _ in range(S):
# deltas[j] accumulates, over all stages, the weighted gain of giving
# column j one more piece.
deltas = np.zeros(N, dtype=np.float32)
for i in range(n_stage):
# Find top1 and top2
idx1 = -1
idx2 = -1
val1 = -1.0
val2 = -1.0
for j in range(N):
v = unit[i, j]
if v > val1:
val2 = val1
idx2 = idx1
val1 = v
idx1 = j
elif v > val2:
val2 = v
idx2 = j
origin = unit[i, idx1]
secv = unit[i, idx2]
# Per-piece value of the top column if it received one extra piece.
alt = X[i, idx1] / (pieces[idx1] + 1)
# Reduction of this stage's largest per-piece value: after the split the
# new peak is the larger of the split top column and the runner-up.
delta = origin - (alt if alt > secv else secv)
# When delta is zero the top column is still credited with the bare
# stage weight.
deltas[idx1] += delta * stage_weights[i] if np.any(
delta) != 0 else stage_weights[i]
# The column with the largest accumulated gain gets the extra piece, and its
# per-piece value is refreshed for every stage.
max_idx = np.argmax(deltas)
pieces[max_idx] += 1
for i in range(n_stage):
unit[i, max_idx] = X[i, max_idx] / pieces[max_idx]
# Compute max load
# NOTE(review): max_load is computed below but is neither returned nor read.
max_load = 0.0
for j in range(N):
total = 0.0
for i in range(n_stage):
total += unit[i, j]
if total > max_load:
max_load = total
return pieces
@njit
def jsq_placement(X, pieces, M, stage_weights):
# Place the pieces of every expert onto M devices and return an int32
# deployment array of shape (M, total_pieces // M) holding expert indices.
# X is unpacked as (n_stage, N); pieces[i] is the number of pieces of expert i.
# Experts are visited in descending order of summed per-piece hotness and each
# piece goes to the device with the lowest weighted score (step 4.4).
n_stage, N = X.shape
total_piece = pieces.sum()
num_per_group = total_piece // M
# 1. Compute unit_hotness
unit_hotness = np.empty((n_stage, N), dtype=np.float32)
for i in range(N):
if pieces[i] > 0:
for s in range(n_stage):
unit_hotness[s, i] = X[s, i] / pieces[i]
else:
for s in range(n_stage):
unit_hotness[s, i] = 0.0
# 2. Sort by total hotness
scores = np.zeros(N, dtype=np.float32)
for i in range(N):
for s in range(n_stage):
scores[i] += unit_hotness[s, i]
idx = np.argsort(-scores)
# 3. Initialization
# loads[s, j]: accumulated hotness of device j in stage s.
# deployment starts filled with -1, which marks an unassigned slot.
loads = np.zeros((n_stage, M), dtype=np.float32)
dev_phy_exp_n = np.zeros(M, dtype=np.int32)
deployment = -np.ones((M, num_per_group), dtype=np.int32)
dep_ptr = np.zeros(M, dtype=np.int32)
# 4. Main loop
for t in range(N):
i = idx[t]
# Devices that already received a piece of expert i in this pass.
used_device = list()
for _ in range(pieces[i]):
# 4.1 Construct w vector
w = np.empty(n_stage, dtype=np.float32)
for s in range(n_stage):
w[s] = unit_hotness[s, i]
# 4.2 Compute stage-level maximum load
stage_max = np.empty(n_stage, dtype=np.float32)
for s in range(n_stage):
max_val = loads[s, 0]
for k in range(1, M):
if loads[s, k] > max_val:
max_val = loads[s, k]
stage_max[s] = max_val
# 4.3 Compute denominator
# denom[s] is the mean over devices of (load + w[s]) plus a 1e-2 offset.
denom = np.empty(n_stage, dtype=np.float32)
for s in range(n_stage):
sum_tmp = 0.0
for j in range(M):
sum_tmp += loads[s, j] + w[s]
denom[s] = sum_tmp / M + 1e-2
# 4.4 Find best device j
# Full devices and devices already holding this expert are skipped.
best_j = -1
best_val = 1e30
for j in range(M):
if dev_phy_exp_n[j] >= num_per_group:
continue
if j in used_device:
continue
score = 0.0
for s in range(n_stage):
tmp_sj = loads[s, j] + w[s]
numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s]
score += stage_weights[s] * (numer_sj / denom[s])
if score < best_val:
best_val = score
best_j = j
# No eligible device: this piece is not placed; its slot is filled by the
# random fallback after the main loop.
if best_j == -1:
continue
used_device.append(best_j)
# 4.5 Update status
for s in range(n_stage):
loads[s, best_j] += w[s]
ptr = dep_ptr[best_j]
deployment[best_j, ptr] = i
dep_ptr[best_j] += 1
dev_phy_exp_n[best_j] += 1
# Handle remaining -1 values: fill with random elements from range(N) not in current column
for rank in range(M):
for col in range(num_per_group):
if deployment[rank, col] == -1:
# Get elements already in current column
current_rank_elements = set(deployment[rank, :])
# Filter elements from range(N) not in current column
available = [
x for x in range(N) if x not in current_rank_elements
]
# Randomly select an available element to fill
if len(available) > 0:
rand_idx = np.random.randint(0, len(available))
deployment[rank, col] = available[rand_idx]
elif N > 0:
# All unique experts are already in this rank's column, so we can pick any expert randomly.
deployment[rank, col] = np.random.randint(0, N)
return deployment
@njit
def slice_values(X, pieces):
    """Expand each ``X[k]`` into ``pieces[k]`` equal shares.

    Returns a float32 array of length ``sum(pieces)`` in which the value
    ``X[k] / pieces[k]`` is repeated ``pieces[k]`` times, in index order.
    """
    n_items = X.shape[0]
    n_slices = 0
    for k in range(n_items):
        n_slices += pieces[k]
    out = np.empty(n_slices, dtype=np.float32)
    write_pos = 0
    for k in range(n_items):
        share = X[k] / pieces[k]
        stop = write_pos + pieces[k]
        while write_pos < stop:
            out[write_pos] = share
            write_pos += 1
    return out
@njit
def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
simulated_deployment, stage_weights):
# Decide the per-expert piece counts group by group and return them as an
# int32 array of length N.
# X is unpacked as (n_stage, N). Experts are ordered by total load
# (descending) and consumed in num_group = P // M groups. For each group a
# binary search picks S, the number of the group's M slots given to extra
# pieces, and the group's first M - S experts are split by
# compute_piece_counts.
n_stage, N = X.shape
num_group = P // M
X_all = np.zeros(N, dtype=np.float32)
for i in range(n_stage):
for j in range(N):
X_all[j] += X[i, j]
sort_idx = np.argsort(np.negative(X_all))
X_sorted = X[:, sort_idx]
unit_load = np.empty(N, dtype=np.float32)
for j in range(N):
unit_load[j] = X_all[j] / simulated_pieces[j]
flat_deployment = simulated_deployment.reshape(-1)
# NOTE(review): simulated_load is filled here but not read afterwards.
simulated_load = np.zeros(M, dtype=np.float32)
for i in range(flat_deployment.shape[0]):
simulated_load[i // (flat_deployment.shape[0] //
M)] += unit_load[flat_deployment[i]]
slice_vals = slice_values(X_all, simulated_pieces)
sorted_slices = np.sort(slice_vals)[::-1]
# Difference between slices that are M - 1 positions apart in the descending
# order, divided by M.
# NOTE(review): for M == 1 the first operand `[:0]` is empty while the second
# is not - confirm callers always pass M > 1.
simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M
cumulative_slices_used = np.zeros(N, dtype=np.int32)
acc = 0
for i in range(N):
acc += simulated_pieces[sort_idx[i]]
cumulative_slices_used[i] = acc
group_boundary_indices = np.zeros(num_group, dtype=np.int32)
for i in range(1, num_group + 1):
for j in range(N):
if cumulative_slices_used[j] >= i * M:
group_boundary_indices[i - 1] = j
break
# NOTE(review): slices_used_per_group (and the boundary indices feeding it) is
# computed here but not read in the loop below.
slices_used_per_group = np.zeros(num_group, dtype=np.int32)
slices_used_per_group[0] = group_boundary_indices[0]
for i in range(1, num_group):
slices_used_per_group[
i] = group_boundary_indices[i] - group_boundary_indices[i - 1]
slices_used_per_group = M - slices_used_per_group
loads = np.zeros(M, dtype=np.float32)
pieces = np.zeros(N, dtype=np.int32)
num_remain_slice = P - N
current_idx = 0
for g in range(num_group):
window = X_sorted[:, current_idx:current_idx + 2 * M]
# Search range for the number of extra pieces spent on this group.
low = max(0, current_idx + M - N)
high = min(num_remain_slice, M - 1)
while (high - low) > 1:
mid = int((high + low) // 2)
keep = M - mid
current_group = window[:, :keep]
current_pieces = compute_piece_counts(current_group, M,
stage_weights)
current_pieces = np.maximum(current_pieces, 1)
current_slice = slice_values(current_group.sum(0), current_pieces)
current_slice_sorted = np.sort(current_slice)
current_loads = loads + current_slice_sorted
current_max: np.float32 = np.max(current_loads)
current_min: np.float32 = np.min(current_loads)
# Spread of the candidate loads compared with the largest simulated slope
# among the experts that would remain after this group.
current_slope = (current_max - current_min) / M
next_slope: np.float32 = np.max(simulated_slopes[current_idx +
keep:])
if abs(current_slope) > abs(next_slope):
low = mid
else:
high = mid
# The upper bound of the search is the value that is applied.
S = high
keep = M - S
current_group = window[:, :keep]
current_pieces = compute_piece_counts(current_group, M, stage_weights)
for i in range(keep):
pieces[sort_idx[current_idx + i]] = current_pieces[i]
current_slice = slice_values(current_group.sum(0), current_pieces)
# New slices sorted ascending are added onto loads kept in descending order,
# so the largest existing load is paired with the smallest new slice.
current_slice_sorted = np.sort(current_slice)
loads += current_slice_sorted
loads = np.sort(loads)[::-1]
current_idx += keep
num_remain_slice -= S
return pieces
@njit
def compute_objective(deployment, X, pieces):
    """Return ``(max_load / mean_load, loads)`` for a deployment.

    ``loads[d]`` is the sum of ``X[e] / pieces[e]`` over every slot of device
    ``d`` holding expert ``e``; slots whose expert has zero pieces are skipped.
    """
    n_dev, n_slot = deployment.shape
    loads = np.zeros(n_dev)
    for dev in range(n_dev):
        for slot in range(n_slot):
            e = deployment[dev, slot]
            if pieces[e] != 0:
                loads[dev] += X[e] / pieces[e]
    avg = np.mean(loads)
    peak = np.max(loads)
    return peak / avg, loads
@njit
def auto_fix_new_placement(old_placement, new_placement):
"""
Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both
old_placement and new_placement remain in their original positions from old_placement.
New elements (unique to new_placement) will fill the remaining empty positions.
Args:
old_placement: Old deployment matrix with shape (num_ranks, num_experts)
new_placement: New deployment matrix to be fixed, must have the same shape as old_placement
Returns:
fixed_new: adjusted version of the new_placement matrix
"""
num_ranks, num_experts = old_placement.shape
fixed_new = np.empty_like(new_placement)
# Largest expert id seen in either matrix; it sizes the lookup tables below.
max_expert_old = old_placement.max() if num_experts > 0 else 0
max_expert_new = new_placement.max() if num_experts > 0 else 0
max_expert = max(max_expert_old, max_expert_new)
for rank_id in range(num_ranks):
old_row = old_placement[rank_id]
new_row = new_placement[rank_id]
# index_array[val, k] is the k-th position of val in the old row and
# count_array[val] is how many positions were recorded for val.
index_array = np.full((max_expert + 1, num_experts),
-1,
dtype=np.int32)
count_array = np.zeros(max_expert + 1, dtype=np.int32)
for idx in range(num_experts):
val = old_row[idx]
if val >= 0 and val <= max_expert:
pos = count_array[val]
index_array[val, pos] = idx
count_array[val] += 1
# A second occurrence counter for the old row; it is consumed while the new
# row is classified so duplicates are retained at most as often as they
# occurred before.
old_counter = np.zeros(max_expert + 1, dtype=np.int32)
for idx in range(num_experts):
val = old_row[idx]
if val >= 0 and val <= max_expert:
old_counter[val] += 1
# Split the new row into values that were already on this rank (retained)
# and values that are new to it.
retain_elements = np.empty(num_experts, dtype=new_placement.dtype)
new_elements = np.empty(num_experts, dtype=new_placement.dtype)
retain_ptr = 0
new_ptr = 0
for val in new_row:
if val >= 0 and val <= max_expert and old_counter[val] > 0:
retain_elements[retain_ptr] = val
retain_ptr += 1
old_counter[val] -= 1
else:
new_elements[new_ptr] = val
new_ptr += 1
# Write retained values back to positions they held in the old row, taking
# the last recorded position of each value first. -1 marks an empty slot.
current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype)
for i in range(retain_ptr):
val = retain_elements[i]
if val >= 0 and val <= max_expert:
pos = count_array[val] - 1
if pos >= 0:
idx = index_array[val, pos]
current_fixed[idx] = val
count_array[val] -= 1
# Collect the slots that are still empty, in ascending position order.
empty_indices = np.empty(num_experts, dtype=np.int32)
empty_ptr = 0
for idx in range(num_experts):
if current_fixed[idx] == -1:
empty_indices[empty_ptr] = idx
empty_ptr += 1
# New values fill the empty slots in the order they appeared in the new row.
for i in range(new_ptr):
if i < empty_ptr:
current_fixed[empty_indices[i]] = new_elements[i]
fixed_new[rank_id] = current_fixed
return fixed_new
class FlashLB(EplbPolicy):
    """EPLB policy that re-plans expert replicas from recent load history.

    For every layer a bounded window of per-expert hotness vectors ("stages")
    is kept. On each call a new replica count and placement is computed per
    layer and only the layers whose peak-to-average load ratio (PAR) improves
    are reported for update.
    """

    def __init__(self, config: DynamicConfig):
        """Read the optional tuning attributes from ``config``.

        ``max_stage_window`` (default 1), ``buffer_expert_layer_num``
        (default 58) and ``threshold_ratio`` (default 0) are used when
        present on ``config``.
        """
        super().__init__(config)
        # layer id -> PAR recorded when that layer was last selected for update
        self.par_history: Dict[int, float] = {}
        # layer id -> bounded window of per-expert hotness arrays
        self.hotness_window: Dict[int, deque] = {}
        self.max_stage_window = getattr(config, "max_stage_window", 1)
        self.buffer_expert_layer_num = getattr(config,
                                               "buffer_expert_layer_num", 58)
        self.threshold_ratio = getattr(config, "threshold_ratio", 0)

    def compute_expert_hotness(self, num_of_expert: int,
                               deployment: np.ndarray, rank_load: np.ndarray):
        """Sum the load of every physical slot onto its logical expert id.

        Returns an array of length ``num_of_expert`` with ``rank_load``'s dtype.
        """
        hotness = np.zeros(num_of_expert, dtype=rank_load.dtype)
        # np.add.at accumulates correctly when an expert id occurs repeatedly.
        np.add.at(hotness, deployment.ravel(), rank_load.ravel())
        return hotness

    def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray):
        """Return the mean over stages of max/mean rank load for ``deployment``.

        ``hotness`` is unpacked as (n_stage, N). Each expert's hotness is
        divided evenly between its replicas in ``deployment``.

        Raises:
            ValueError: if ``deployment`` contains a negative entry.
        """
        n_stage, N = hotness.shape
        if np.any(deployment < 0):
            logging.getLogger(__name__).error(
                "Invalid deployment with negative values: %s", deployment)
            raise ValueError("Deployment table contains negative values.")
        counts = np.bincount(deployment.reshape(-1), minlength=N)
        unit_hotness = np.divide(hotness,
                                 counts,
                                 out=np.zeros_like(hotness, dtype=float),
                                 where=counts != 0)
        stage_par = np.zeros(n_stage)
        for i in range(n_stage):
            stage_load = unit_hotness[i][deployment].sum(-1)
            stage_par[i] = stage_load.max() / stage_load.mean()
        return stage_par.mean()

    def group_based_adaptive_bloating(self,
                                      X,
                                      P,
                                      M,
                                      stage_weights=None,
                                      recorsive=False):
        """Compute a placement of ``P`` expert pieces on ``M`` ranks.

        A baseline ("simulated") solution is produced first, either by a
        nested call (``recorsive=True``) or directly from
        ``compute_piece_counts`` + ``jsq_placement``. The group-based kernel
        then proposes new piece counts; whichever of the two solutions has
        the lower peak rank load is returned.

        Returns:
            ``(deployment, pieces)``.
        """
        n_stage, _ = X.shape
        if stage_weights is None:
            stage_weights = np.ones(n_stage, dtype=np.float32)

        if recorsive:
            simulated_deployment, simulated_pieces = (
                self.group_based_adaptive_bloating(X,
                                                   P,
                                                   M,
                                                   stage_weights,
                                                   recorsive=False))
        else:
            simulated_pieces = compute_piece_counts(X, P, stage_weights)
            simulated_deployment = jsq_placement(X, simulated_pieces, M,
                                                 stage_weights)

        pieces = group_based_adaptive_bloating_kernel(
            X.astype(np.float32),
            P,
            M,
            simulated_pieces.astype(np.int32),
            simulated_deployment.astype(np.int32),
            stage_weights.astype(np.float32),
        )
        deployment = jsq_placement(X, pieces, M, stage_weights)

        X_all = X.sum(0)
        unit_load = np.divide(X_all,
                              pieces,
                              out=np.zeros_like(X_all, dtype=float),
                              where=pieces != 0)
        load = unit_load[deployment].sum(-1)

        sim_unit_load = X_all / simulated_pieces
        sim_load = sim_unit_load[simulated_deployment].sum(-1)

        # Keep the baseline when the kernel's proposal has a higher peak load.
        if load.max() > sim_load.max():
            return simulated_deployment, simulated_pieces
        return deployment, pieces

    def need_update(self, current_par, layer_id=0):
        """Tell whether ``current_par`` reaches the layer's update threshold.

        The threshold is ``threshold_ratio`` times the PAR stored for the
        layer in ``par_history`` (0.0 when nothing is stored).
        """
        threshold = self.par_history.get(layer_id, 0.0)
        return current_par >= self.threshold_ratio * threshold

    def compute_stage_weight(self, hotness):
        """Return each stage's total hotness divided by the largest total."""
        n_stage = hotness.shape[0]
        stage_weights = np.zeros(n_stage)
        for i in range(n_stage):
            stage_weights[i] = hotness[i].sum()
        return stage_weights / stage_weights.max()

    def rebalance_layer(self, deployment, hotness, layer_id=0):
        """Re-plan one layer.

        Returns:
            ``(new_deployment, new_par, current_par)``. When the layer does
            not need an update the input deployment is returned and both PAR
            values equal the current one.
        """
        num_rank, expert_per_rank = deployment.shape
        num_expert = np.unique(deployment.reshape(-1)).shape[0]
        num_of_redundant_expert = num_rank * expert_per_rank - num_expert

        current_par = self.compute_rank_load(deployment, hotness)
        if not self.need_update(current_par, layer_id):
            return deployment, current_par, current_par

        stage_weights = self.compute_stage_weight(hotness)
        new_deployment, _ = self.group_based_adaptive_bloating(
            hotness,
            num_expert + num_of_redundant_expert,
            num_rank,
            stage_weights,
            recorsive=False,
        )
        if np.any(new_deployment < 0):
            logging.getLogger(__name__).warning("new_deployment=%r",
                                                new_deployment)
        new_par = self.compute_rank_load(new_deployment, hotness)
        return new_deployment, new_par, current_par

    def register_hotness(self, deployment, rank_load, num_layer, num_expert):
        """Append the current per-expert hotness of every layer to its window."""
        for layer in range(num_layer):
            if layer not in self.hotness_window:
                self.hotness_window[layer] = deque(
                    maxlen=self.max_stage_window)
            hotness = self.compute_expert_hotness(num_expert,
                                                  deployment[layer],
                                                  rank_load[layer])
            self.hotness_window[layer].append(hotness)

    def compress_by_avg_pooling_fast_nd(self, arr, m):
        """Average-pool the ``n`` rows of ``arr`` into ``m`` rows.

        Row ``r`` contributes to bucket ``r * m // n``; each bucket holds the
        mean of its rows.
        """
        n, d = arr.shape
        idx = (np.arange(n) * m // n)
        result = np.zeros((m, d))
        counts = np.zeros((m, 1))
        np.add.at(result, idx, arr)
        np.add.at(counts, idx, 1)
        return result / counts

    def rebalance_experts(self, current_expert_table, expert_workload):
        """Plan a new expert deployment for all layers.

        Args:
            current_expert_table: current placement, convertible with
                ``np.array`` to (num_layer, num_rank, experts_per_rank).
            expert_workload: load per physical slot with the same layout.

        Returns:
            ``(change, layers_to_update, deployment)``. ``change`` is False
            when no layer improves or when every slot reported zero load; in
            the latter case the snapshot just recorded is discarded again.
        """
        current_deployment = np.array(current_expert_table)
        # np.array copies, so the +1 smoothing does not touch the caller's data.
        expert_workload = np.array(expert_workload)
        expert_workload += 1
        num_layer = expert_workload.shape[0]
        # Read from the normalised array so list inputs work as well.
        num_expert = np.unique(current_deployment[0].reshape(-1)).shape[0]
        self.register_hotness(current_deployment, expert_workload, num_layer,
                              num_expert)

        # An all-idle step carries no signal: drop its snapshot and return
        # before spending time on the per-layer optimisation.
        if np.all(expert_workload == 1):
            for layer in range(num_layer):
                self.hotness_window[layer].pop()
            return False, np.array([], dtype=int), current_deployment

        new_deployment = current_deployment.copy()
        layers_need_update = np.arange(num_layer)
        new_par = np.zeros(layers_need_update.shape[0])
        current_par = np.zeros(layers_need_update.shape[0])
        for i, layer in enumerate(layers_need_update):
            hotness = np.array(self.hotness_window[layer])
            if hotness.shape[0] > self.max_stage_window:
                hotness = self.compress_by_avg_pooling_fast_nd(
                    hotness, self.max_stage_window)
            (
                new_deployment[layer],
                new_par[i],
                current_par[i],
            ) = self.rebalance_layer(current_deployment[layer],
                                     hotness,
                                     layer_id=layer)

        # Keep only layers whose PAR improves, best improvement first, capped
        # at buffer_expert_layer_num layers.
        priority = new_par / current_par
        priority_idx = np.argsort(priority)
        priority_idx = priority_idx[priority[priority_idx] <
                                    1][:self.buffer_expert_layer_num]

        change = len(priority_idx) > 0
        if change:
            for idx in priority_idx:
                self.par_history[int(
                    layers_need_update[idx])] = float(new_par[idx])

        layers_need_update = priority_idx
        deployment = current_deployment
        for layer in layers_need_update:
            # Keep experts that stay on a rank in their existing slot.
            deployment[layer] = auto_fix_new_placement(
                current_deployment[layer], new_deployment[layer])
        return change, layers_need_update, deployment
def generate_layered_experts(num_layers=58,
                             layer_shape=(32, 9),
                             expert_min=0,
                             expert_max=255):
    """Build a random expert deployment tensor that covers every expert.

    Each layer holds every expert id in ``[expert_min, expert_max]`` at least
    once; the slots left over are filled with ids drawn uniformly from the
    same range, and the layer is shuffled before being reshaped.

    Args:
        num_layers: Number of layers to generate.
        layer_shape: ``(rows, cols)`` of one layer.
        expert_min: Smallest expert id (inclusive).
        expert_max: Largest expert id (inclusive).

    Returns:
        torch.Tensor: int64 tensor of shape
        ``(num_layers, layer_shape[0], layer_shape[1])``.
    """
    expert_num = expert_max - expert_min + 1
    layer_total = layer_shape[0] * layer_shape[1]
    spare_slots = layer_total - expert_num

    # A layer needs at least one slot per expert.
    assert layer_total >= expert_num, (
        f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, "
        "cannot cover all experts")

    id_upper = expert_max + 1

    def _build_layer():
        mandatory = torch.arange(expert_min, id_upper, dtype=torch.int64)
        padding = torch.randint(expert_min,
                                id_upper,
                                size=(spare_slots, ),
                                dtype=torch.int64)
        pool = torch.cat([mandatory, padding], dim=0)
        order = torch.randperm(pool.shape[0])
        return pool[order].reshape(layer_shape)

    return torch.stack([_build_layer() for _ in range(num_layers)], dim=0)
def warm_up():
    """Run one FlashLB rebalance on synthetic data.

    Uses a 58-layer, 32x9 random deployment and random integer workloads in
    ``[1, 1000)``; the result of the rebalance is discarded.
    """
    n_layers, per_layer = 58, (32, 9)

    cfg = DynamicConfig()
    cfg.ep_worldsize = 32
    cfg.num_die_per_host = 16
    policy = FlashLB(cfg)

    # Generate target tensor
    placement = generate_layered_experts(num_layers=n_layers,
                                         layer_shape=per_layer)
    workload = torch.randint(1, 1000, (n_layers, *per_layer))
    policy.rebalance_experts(placement, workload)

View File

@@ -0,0 +1,30 @@
# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
import copy
import random
from .policy_abstract import DynamicConfig, EplbPolicy
# Seeds Python's global `random` generator when this module is imported.
# NOTE(review): the only sampling call visible in this module is commented
# out - confirm the seed is still needed.
random.seed(42)
class RandomLoadBalance(EplbPolicy):
    """Policy that swaps the last expert slot of two fixed cards per layer."""

    def __init__(self, config: DynamicConfig):
        super().__init__(config)

    def rebalance_experts(self, current_expert_table, expert_workload):
        """Return a copy of the table with cards 3 and 1 swapped at slot -1.

        ``expert_workload`` is not consulted. The swapped entries must
        support ``.clone()``.

        Returns:
            ``(1, priorities, new_table)`` where ``priorities`` is
            ``[0, -1, -2, ...]`` with one entry per layer.
        """
        layer_count = len(current_expert_table)
        new_table = copy.deepcopy(current_expert_table)

        # randomly choose two card
        # indices = random.sample(range(num_card), 2)
        first_card, second_card = 3, 1

        for layer_id in range(layer_count):
            # swap redundant experts
            held = new_table[layer_id][first_card][-1].clone()
            new_table[layer_id][first_card][-1] = new_table[layer_id][
                second_card][-1]
            new_table[layer_id][second_card][-1] = held

        return 1, list(range(0, -layer_count, -1)), new_table

View File

@@ -0,0 +1,205 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this updator.
import numpy
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import logger
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
class EplbUpdator:
    """Drives the EPLB update cycle from the model runner's forward hooks.

    The cycle is paced by ``cur_iterations``:

    * at ``num_iterations_eplb_update - 1`` the gathered expert load is
      published and the EPLB worker process is woken up;
    * at ``num_iterations_eplb_update + num_wait_worker_iterations - 1`` the
      worker's update plan is fetched (blocking);
    * during the following ``num_moe_layers`` iterations one planned layer
      update is applied per iteration;
    * afterwards the loads are cleared and, unless ``gate_eplb`` is set, the
      counter restarts from zero.
    """

    def __init__(self, ascend_config, loader, eplb_process: EplbProcess,
                 process):
        self.ascend_config = ascend_config
        self.init_eplb(self.ascend_config.expert_map_path, process)
        self.eplb_loader = loader
        self.eplb_process = eplb_process
        self.shared_dict = self.eplb_process.shared_dict

    def set_adaptor(self, adaptor):
        """Attach the model adaptor and cache its layer / expert counts."""
        self.adaptor = adaptor
        self.num_moe_layers = self.adaptor.num_moe_layers
        self.global_expert_num = self.adaptor.global_expert_num

    def init_eplb(self, expert_map_path, process):
        """Initialise counters and settings from ``self.ascend_config``.

        Requires ``torch.distributed`` to be initialised (the rank is read).
        """
        self.rank_id = dist.get_rank()
        self.num_expert_load_gather = 10
        self.periodic_load_gather = True
        self.num_iterations_eplb_update: int = self.ascend_config.num_iterations_eplb_update
        self.expert_map_path = expert_map_path
        self.expert_map_record_path = self.ascend_config.expert_map_record_path

        # Best effort: if the switch cannot be read for any reason, behave as
        # if periodic load collection were disabled.
        try:
            collecting_allowed = bool(envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING)
        except Exception:
            collecting_allowed = False
        if not collecting_allowed:
            self.num_expert_load_gather = self.num_iterations_eplb_update
            self.periodic_load_gather = False

        self.expert_map_initialized = False
        self.gate_eplb = self.ascend_config.gate_eplb

        self.reqs = []
        self.update_info_all = []

        self.cur_iterations: int = 0

        self.num_wait_worker_iterations: int = self.ascend_config.num_wait_worker_iterations

        self.process = process

        logger.info(
            f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")

    def update_iteration(self):
        """Advance the counter and finish the cycle when its end is reached."""
        self.cur_iterations += 1
        cycle_length = (self.num_iterations_eplb_update +
                        self.num_wait_worker_iterations + self.num_moe_layers)
        if self.cur_iterations == cycle_length:
            if self.expert_map_record_path is not None:
                self.adaptor._export_tensor_to_file(
                    self.shared_dict["expert_maps"],
                    self.expert_map_record_path)

            self.adaptor.model.clear_all_moe_loads()
            # With gate_eplb set the counter keeps growing, so no further
            # cycle is triggered.
            if not self.gate_eplb:
                self.cur_iterations = 0

    def get_update_info_flag(self):
        """True on the iteration at which the worker's plan is fetched."""
        return self.cur_iterations == (self.num_iterations_eplb_update +
                                       self.num_wait_worker_iterations - 1)

    def wakeup_eplb_worker_flag(self):
        """True on the iteration at which the worker is woken up."""
        return self.cur_iterations == (self.num_iterations_eplb_update - 1)

    def update_expert_weight_flag(self):
        """True during the ``num_moe_layers`` iterations that apply updates."""
        weight_update_counter = self.cur_iterations - (
            self.num_iterations_eplb_update + self.num_wait_worker_iterations)
        return 0 <= weight_update_counter < self.num_moe_layers

    def get_init_expert_map(self):
        """Publish the initial expert maps to the shared dict once.

        Failures are logged and swallowed.
        """
        try:
            if not self.expert_map_initialized:
                self.shared_dict[
                    "expert_maps"] = self.adaptor.get_init_expert_map_from_file(
                        self.num_moe_layers, self.expert_map_path)
                self.expert_map_initialized = True
        except Exception as e:
            logger.warning(
                f"[ModelRunner] Failed to initialize expert map: {e}",
                exc_info=True)

    def wakeup_eplb_worker(self):
        """Signal the EPLB worker through its planner queue."""
        self.eplb_process.planner_q.put(1)

    def forward_before(self):
        """Before a forward pass: start the next planned layer transfer."""
        if self.update_expert_weight_flag():
            (expert_send_info, expert_recv_info, updated_expert_map,
             log2phy_map, layer_id) = self.update_info_all.pop(0)
            log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
            self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
            updated_expert_map_this_rank = torch.from_numpy(
                numpy.array(updated_expert_map))
            self.eplb_loader.generate_expert_d2d_transfer_task(
                expert_send_info, expert_recv_info,
                updated_expert_map_this_rank,
                layer_id + self.adaptor.num_dense_layers)

            # set asynchronous stream for d2d expert weight update
            self.reqs = []
            self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

    def take_update_info_from_eplb_process(self):
        """Fetch the worker's update plan on the designated iteration."""
        # Batch after eplb process being triggered, get update info provided by eplb process
        if self.get_update_info_flag():
            self.update_info_all = self.eplb_process.block_update_q.get()

    def forward_end(self):
        """After a forward pass: wake the worker / commit a transfer, then tick."""
        if self.wakeup_eplb_worker_flag():
            self.compute_and_set_moe_load(is_clear=True)
            self.wakeup_eplb_worker()

        if self.update_expert_weight_flag():
            self.eplb_loader.update_expert_map_and_weight(self.reqs)

        self.update_iteration()

    def compute_and_set_moe_load(self, is_clear=False):
        """Gather the expert load of all ranks and publish it.

        A CPU copy of the gathered load is stored in
        ``shared_dict["moe_load"]`` and the (device) tensor is returned.
        ``is_clear`` is accepted but not used here.
        """
        local_load = self.adaptor.get_rank_expert_workload()

        # A fresh gather buffer is used on every call; it stays None when
        # torch.distributed is not initialised.
        self._gather_buffer = None
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
            self.device = local_load.device
            self._gather_buffer = torch.empty(
                (self.world_size, *local_load.shape),
                dtype=local_load.dtype,
                device=self.device)
            dist.all_gather_into_tensor(self._gather_buffer, local_load)
            moe_load = self._gather_buffer.permute(1, 0, 2)
        else:
            moe_load = local_load.unsqueeze(1)

        self.shared_dict["moe_load"] = moe_load.cpu()
        logger.debug(
            f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}"
        )
        return moe_load

    def warm_up_eplb(self):
        """Initialise the expert map, gather load once and exercise P2P links.

        A one-element tensor is sent to and received from every other rank.
        """
        self.get_init_expert_map()
        self.compute_and_set_moe_load()

        src_tensor = torch.empty((1, ), device=self.device)
        self_rank = dist.get_rank()

        peers = [
            rank for rank in range(self.world_size) if rank != self_rank
        ]
        comm_op_list = [
            dist.P2POp(dist.isend, src_tensor, dst_rank) for dst_rank in peers
        ]
        comm_op_list.extend(
            dist.P2POp(dist.irecv, src_tensor, src_rank)
            for src_rank in peers)

        if comm_op_list:
            reqs = dist.batch_isend_irecv(comm_op_list)
            for req in reqs:
                req.wait()

    def shutdown(self):
        """
        Clean up the EPLB process.
        """
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()
            logger.info("[ModelRunner] EPLB process terminated")

77
vllm_ascend/eplb/utils.py Normal file
View File

@@ -0,0 +1,77 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/pull/23553 is merged in vllm. Remove this model register.
import types
import torch
def get_expert_map(self, layer_id):
    """Return ``get_map()`` of the experts module in layer ``layer_id``."""
    experts = self.model.layers[layer_id].mlp.experts
    return experts.get_map()
def get_log2phy_map(self, layer_id):
    """Return ``get_log2phy_map()`` of the experts module in layer ``layer_id``."""
    experts = self.model.layers[layer_id].mlp.experts
    return experts.get_log2phy_map()
def get_all_expert_map(self, num_moe_layers):
    """Stack the expert maps of ``num_moe_layers`` consecutive MoE layers.

    Layer indices start after ``self.num_dense_layers`` (0 when the attribute
    is absent). The result has one row per MoE layer.
    """
    dense_offset = getattr(self, "num_dense_layers", 0)
    per_layer_maps = [
        self.get_expert_map(moe_idx + dense_offset)
        for moe_idx in range(num_moe_layers)
    ]
    return torch.stack(per_layer_maps, dim=0)
def get_all_moe_loads(self):
    """Stack the ``moe_load`` tensor of every MoE layer along a new dim 0.

    Visits ``self.num_moe_layers`` layers of ``self.model.layers``, starting
    at index ``self.num_dense_layers`` when that attribute exists and at 0
    otherwise.
    """
    dense_offset = getattr(self, "num_dense_layers", 0)
    decoder_layers = self.model.layers
    loads = []
    for moe_idx in range(self.num_moe_layers):
        experts = decoder_layers[moe_idx + dense_offset].mlp.experts
        loads.append(experts.moe_load)
    return torch.stack(loads, dim=0)
def clear_all_moe_loads(self):
    """Call ``clear_moe_load()`` on the experts module of every MoE layer.

    The MoE layers are taken to be the ``self.num_moe_layers`` layers that
    start at index ``self.num_dense_layers`` (0 when the attribute is absent).
    """
    first_moe = getattr(self, "num_dense_layers", 0)
    for layer_idx in range(first_moe, first_moe + self.num_moe_layers):
        self.model.layers[layer_idx].mlp.experts.clear_moe_load()
def model_register(model, model_config):
    """Attach the EPLB helper methods and MoE layer bookkeeping to ``model``.

    The module-level helpers are bound to ``model`` as instance methods, and
    the number of MoE layers is derived from ``model_config.hf_config``.

    Args:
        model: The model object to patch in place.
        model_config: Object exposing ``hf_config`` with ``model_type``,
            ``num_hidden_layers`` and, for DeepSeek, ``first_k_dense_replace``.

    Raises:
        NotImplementedError: If ``hf_config.model_type`` is not one of
            ``qwen3_moe``, ``deepseek_v2`` or ``deepseek_v3``.
    """
    model.get_expert_map = types.MethodType(get_expert_map, model)
    model.get_log2phy_map = types.MethodType(get_log2phy_map, model)
    model.get_all_expert_map = types.MethodType(get_all_expert_map, model)
    model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
    model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)

    config = model_config.hf_config

    if config.model_type == "qwen3_moe":
        model.num_moe_layers = config.num_hidden_layers
    elif config.model_type in ("deepseek_v2", "deepseek_v3"):
        num_dense_layers = config.first_k_dense_replace
        # The bound helpers offset layer indices by `self.num_dense_layers`
        # and fall back to 0 when it is missing. Store it on the model so the
        # leading dense layers are skipped even when the model class does not
        # define the attribute itself; an existing value is overwritten with
        # the config value.
        model.num_dense_layers = num_dense_layers
        model.num_moe_layers = config.num_hidden_layers - num_dense_layers
    else:
        raise NotImplementedError("EPLB is not supported.")

View File

@@ -21,7 +21,7 @@ def bgmv_shrink(inputs: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0):
return torch.ops._C.bgmv_shrink(
return torch.ops._C_ascend.bgmv_shrink(
inputs,
lora_a_weights,
lora_indices_tensor,
@@ -35,7 +35,7 @@ def bgmv_expand(inputs: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True):
return torch.ops._C.bgmv_expand(
return torch.ops._C_ascend.bgmv_expand(
inputs,
lora_b_weights,
lora_indices_tensor,
@@ -52,9 +52,9 @@ def bgmv_expand_slice(inputs: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True):
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
lora_indices_tensor, output_tensor,
slice_offset, slice_size)
return torch.ops._C_ascend.bgmv_expand(inputs, lora_b_weights,
lora_indices_tensor, output_tensor,
slice_offset, slice_size)
def sgmv_shrink(
@@ -69,9 +69,9 @@ def sgmv_shrink(
token_nums: int,
scaling: float,
):
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights,
lora_indices_tensor, seq_len_tensor,
output_tensor, scaling)
return torch.ops._C_ascend.sgmv_shrink(inputs, lora_a_weights,
lora_indices_tensor, seq_len_tensor,
output_tensor, scaling)
def sgmv_expand(inputs: torch.Tensor,
@@ -84,7 +84,7 @@ def sgmv_expand(inputs: torch.Tensor,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False):
return torch.ops._C.sgmv_expand(
return torch.ops._C_ascend.sgmv_expand(
inputs,
lora_b_weights,
lora_indices_tensor,
@@ -107,6 +107,7 @@ def sgmv_expand_slice(inputs: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = False):
return torch.ops._C.sgmv_expand(inputs, lora_b_weights,
lora_indices_tensor, seq_len_tensor,
output_tensor, slice_offset, slice_size)
return torch.ops._C_ascend.sgmv_expand(inputs, lora_b_weights,
lora_indices_tensor, seq_len_tensor,
output_tensor, slice_offset,
slice_size)

View File

@@ -11,12 +11,14 @@ if is_310p():
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
else:
from vllm_ascend.lora.punica_wrapper.lora_ops import (
bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm_ascend.lora.lora_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm_ascend.lora.utils import refresh_all_lora_classes
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
@@ -31,6 +33,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
device: Union[torch.device, str], **kwargs):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
refresh_all_lora_classes()
def _shrink_prefill(
self,
@@ -338,13 +341,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
if lora_a_stacked.dim() == 2:
lora_a_stacked = lora_a_stacked.unsqueeze(0)
if lora_b_stacked.dim() == 2:
lora_b_stacked = lora_b_stacked.unsqueeze(0)
r = lora_a_stacked.size(-1)
r = lora_b_stacked.size(-1)
if buffer is None:
buffer = torch.zeros((x.size(0), r),
@@ -352,13 +349,8 @@ class PunicaWrapperNPU(PunicaWrapperBase):
device=x.device)
indices = self.sampler_indices
if indices.max() >= lora_a_stacked.size(0):
indices = torch.clamp(indices, 0, lora_a_stacked.size(0) - 1)
lora_a_reshaped = lora_a_stacked.transpose(1, 2)
lora_b_reshaped = lora_b_stacked.transpose(1, 2)
bgmv_shrink(x, lora_a_reshaped, buffer, indices, scale)
bgmv_expand(buffer, lora_b_reshaped, y, indices, add_inputs=True)
bgmv_shrink(x, lora_a_stacked, buffer, indices, scale)
bgmv_expand(buffer, lora_b_stacked, y, indices, add_inputs=True)
y = y.view_as(y_org)

110
vllm_ascend/lora/utils.py Normal file
View File

@@ -0,0 +1,110 @@
from typing import Optional
import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
from vllm.lora.layers.utils import _not_fully_sharded_can_replace
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
AscendRowParallelLinear)
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
class AscendColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """LoRA layer class selected for ``AscendColumnParallelLinear`` layers."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only if ``source_layer`` is exactly an
        ``AscendColumnParallelLinear``.

        The comparison is on the exact type, so subclasses do not match.
        The remaining arguments are not consulted.
        """
        layer_type = type(source_layer)
        return layer_type is AscendColumnParallelLinear
class AscendMergedColumnParallelLinearWithLoRA(
        MergedColumnParallelLinearWithLoRA):
    """LoRA layer class selected for ``AscendMergedColumnParallelLinear``
    layers."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only if ``source_layer`` is exactly an
        ``AscendMergedColumnParallelLinear``.

        The comparison is on the exact type, so subclasses do not match.
        The remaining arguments are not consulted.
        """
        layer_type = type(source_layer)
        return layer_type is AscendMergedColumnParallelLinear
class AscendRowParallelLinearWithLoRA(RowParallelLinearWithLoRA):
    """LoRA layer class selected for ``AscendRowParallelLinear`` layers."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only if ``source_layer`` is exactly an
        ``AscendRowParallelLinear``.

        The comparison is on the exact type, so subclasses do not match.
        The remaining arguments are not consulted.
        """
        layer_type = type(source_layer)
        return layer_type is AscendRowParallelLinear
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
    """LoRA layer class selected for ``AscendVocabParallelEmbedding``
    layers."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only if ``source_layer`` is exactly an
        ``AscendVocabParallelEmbedding``.

        The comparison is on the exact type, so subclasses do not match.
        The remaining arguments are not consulted.
        """
        layer_type = type(source_layer)
        return layer_type is AscendVocabParallelEmbedding
class AscendQKVParallelLinearWithLoRA(QKVParallelLinearWithLoRA):
    """LoRA layer class selected for an ``AscendQKVParallelLinear`` layer
    with a single packed module."""

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True if ``source_layer`` is exactly an
        ``AscendQKVParallelLinear`` and ``packed_modules_list`` has one entry.
        """
        if type(source_layer) is not AscendQKVParallelLinear:
            return False
        return len(packed_modules_list) == 1
class AscendMergedQKVParallelLinearWithLoRA(MergedQKVParallelLinearWithLoRA):
    """LoRA layer class selected for an ``AscendQKVParallelLinear`` layer
    with three packed modules."""

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True if ``source_layer`` is exactly an
        ``AscendQKVParallelLinear`` and ``packed_modules_list`` has three
        entries.
        """
        if type(source_layer) is not AscendQKVParallelLinear:
            return False
        return len(packed_modules_list) == 3
def refresh_all_lora_classes():
    """Add the Ascend LoRA layer classes to
    ``vllm.lora.utils._all_lora_classes``."""
    registry = vllm.lora.utils._all_lora_classes
    ascend_lora_classes = (
        AscendColumnParallelLinearWithLoRA,
        AscendMergedColumnParallelLinearWithLoRA,
        AscendRowParallelLinearWithLoRA,
        AscendVocabParallelEmbeddingWithLoRA,
        AscendQKVParallelLinearWithLoRA,
        AscendMergedQKVParallelLinearWithLoRA,
    )
    for lora_cls in ascend_lora_classes:
        registry.add(lora_cls)

View File

@@ -23,7 +23,7 @@ from torch.library import Library
# Do NOT perform any real computation or allocate device memory.
#
# 2. Register your meta function using `register_meta_if_necessary`, providing:
# - The namespace (usually "_C" for custom ops)
# - The namespace (usually "_C_ascend" for custom ops)
# - The operator name (as registered in C++)
# - The Python meta function
# - (Optional) The overload name, if your op has overloads
@@ -39,7 +39,7 @@ from torch.library import Library
#
# For more details, see: https://pytorch.org/docs/stable/notes/extending.html#meta-tensors
lib = Library("_C", "IMPL")
lib = Library("_C_ascend", "IMPL")
def register_meta_if_necessary(ns: str, op_name: str, fn, overload: str = ""):
@@ -97,8 +97,9 @@ def sgmv_expand_meta(x: torch.Tensor, weight: torch.Tensor,
return y_out
register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
register_meta_if_necessary("_C", "get_masked_input_and_mask",
register_meta_if_necessary("_C_ascend", "rotary_embedding",
rotary_embedding_meta)
register_meta_if_necessary("_C_ascend", "get_masked_input_and_mask",
get_masked_input_and_mask_meta)
register_meta_if_necessary("_C", "bgmv_expand", bgmv_expand_meta)
register_meta_if_necessary("_C", "sgmv_expand", sgmv_expand_meta)
register_meta_if_necessary("_C_ascend", "bgmv_expand", bgmv_expand_meta)
register_meta_if_necessary("_C_ascend", "sgmv_expand", sgmv_expand_meta)

View File

@@ -4,23 +4,20 @@ import vllm_ascend.envs as envs_ascend
def register_model():
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
from .deepseek_v3 import CustomDeepseekV3ForCausalLM # noqa: F401
from .qwen2_5_vl import \
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
ModelRegistry.register_model(
"Qwen2VLForConditionalGeneration",
"vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
ModelRegistry.register_model(
"Qwen3VLMoeForConditionalGeneration",
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration"
)
ModelRegistry.register_model(
"Qwen3VLForConditionalGeneration",
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration"
)
if envs_ascend.USE_OPTIMIZED_MODEL:
ModelRegistry.register_model(
"Qwen2_5_VLForConditionalGeneration",
@@ -32,30 +29,32 @@ def register_model():
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
)
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
else:
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v3:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
ModelRegistry.register_model(
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
ModelRegistry.register_model(
"PanguProMoEForCausalLM",
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")

File diff suppressed because it is too large Load Diff

View File

@@ -23,22 +23,20 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
SharedHead)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
class CustomDeepSeekShareHead(SharedHead):
@@ -65,6 +63,7 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
nn.Module.__init__(self)
vllm_config = get_current_vllm_config()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -75,10 +74,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "shared_head"))
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
model_config,
cache_config,
quant_config)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config,
prefix=prefix)
def forward(
self,
@@ -103,8 +100,6 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
@@ -171,7 +166,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata=None, # type: ignore
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers)
@@ -183,14 +178,6 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
class CustomDeepSeekMTP(DeepSeekMTP):
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT.
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
@@ -199,8 +186,6 @@ class CustomDeepSeekMTP(DeepSeekMTP):
prefix=maybe_prefix(
prefix, "model"))
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
@@ -215,4 +200,4 @@ class CustomDeepSeekMTP(DeepSeekMTP):
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states
return hidden_states

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2ForCausalLM
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
    """Subclass of ``CustomDeepseekV2ForCausalLM`` that adds and overrides
    nothing; it only provides a distinct class name."""
    pass

View File

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
@dataclass
class AscendMLAModules:
    """Bundle of the sub-modules handed to ``AscendMultiHeadLatentAttention``.

    The three query-side entries may be ``None``; the remaining entries are
    required.
    """
    # Query-side modules (optional).
    q_a_proj: Optional[nn.Module]
    q_a_layernorm: Optional[nn.Module]
    q_proj: Optional[nn.Module]
    # Key/value-side modules.
    kv_a_proj_with_mqa: nn.Module
    kv_a_layernorm: nn.Module
    kv_b_proj: nn.Module
    # Output projection and rotary embedding.
    o_proj: nn.Module
    rotary_emb: nn.Module
class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
    """MLA layer that runs attention through the ``mla_forward`` custom op.

    Construction builds an ``Attention`` layer with ``use_mla=True`` and
    records this module in the current compilation config's
    ``static_forward_context`` under ``prefix``.
    """

    def __init__(
        self,
        hidden_size: int,
        enable_shared_expert_dp: bool,
        debug_layer_idx: int,
        first_k_dense_replace: int,
        tp_size: int,
        mla_modules: AscendMLAModules,
        num_local_heads: int,
        scaling: float,
        layers: int,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        q_lora_rank: Optional[int],
        qk_nope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Store the configuration, build the attention layer and register
        this module under ``prefix``.

        Raises:
            ValueError: If ``prefix`` is already present in the compilation
                config's ``static_forward_context``.
        """
        # Only nn.Module's initialiser runs; the direct parent's __init__
        # is not invoked.
        nn.Module.__init__(self)

        # Layer placement / parallelism settings consulted in forward().
        self.hidden_size = hidden_size
        self.enable_shared_expert_dp = enable_shared_expert_dp
        self.debug_layer_idx = debug_layer_idx
        self.first_k_dense_replace = first_k_dense_replace
        self.tp_size = tp_size
        self.num_local_heads = num_local_heads
        self.layers = layers

        # MLA dimensions.
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.q_lora_rank = q_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.prefix = prefix

        self.mla_attn = Attention(
            num_heads=num_local_heads,
            head_size=kv_lora_rank + qk_rope_head_dim,
            scale=scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=q_lora_rank,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            qk_head_dim=qk_head_dim,
            v_head_dim=v_head_dim,
            rotary_emb=mla_modules.rotary_emb,
            q_a_proj=mla_modules.q_a_proj,
            q_a_layernorm=mla_modules.q_a_layernorm,
            q_proj=mla_modules.q_proj,
            kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
            kv_a_layernorm=mla_modules.kv_a_layernorm,
            kv_b_proj=mla_modules.kv_b_proj,
            o_proj=mla_modules.o_proj,
        )

        # Record this module by name; mla_forward() later looks the layer up
        # by the same name through the forward context.
        static_context = (
            get_current_vllm_config().compilation_config.static_forward_context)
        if prefix in static_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_context[prefix] = self

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        """Allocate the output buffer and run ``torch.ops.vllm.mla_forward``.

        ``positions``, ``kv_cache`` and ``attn_metadata`` are accepted but not
        read here; only ``hidden_states`` is passed on to the custom op.

        Returns:
            The op's output, viewed as 2-D with the last dim preserved.
        """
        shared_dp = self.enable_shared_expert_dp
        layer_idx = self.debug_layer_idx
        first_dense = self.first_k_dense_replace

        token_count = hidden_states.shape[0]
        need_gather_q_kv = bool(
            shared_dp and first_dense < layer_idx < self.layers)
        if need_gather_q_kv:
            # Simulate all gather to calculate output shape
            token_count = token_count * self.tp_size

        if shared_dp and layer_idx >= first_dense:
            # ceil(token_count / tp_size) rows.
            rows = token_count // self.tp_size
            if token_count % self.tp_size:
                rows += 1
            output_shape = (rows, hidden_states.shape[1])
        else:
            output_shape = hidden_states.shape

        # FIXME: This does not seem right, should make sure the buffer is fixed
        output = torch.empty(output_shape,
                             dtype=hidden_states.dtype,
                             device=hidden_states.device)
        torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
                                   self.prefix)
        return output.view(-1, output_shape[-1])
def mla_forward(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the MLA attention impl of the layer registered as ``layer_name``.

    The layer is looked up in the current forward context's
    ``no_compile_layers``. The attention impl is handed ``output`` as its
    last argument; nothing is returned.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    attn = layer.mla_attn

    # When metadata is present it is indexed by the attention layer's name
    # (presumably a per-layer mapping -- confirm); a falsy value is passed
    # through unchanged.
    metadata = context.attn_metadata
    if metadata:
        metadata = metadata[attn.layer_name]

    kv_cache = attn.kv_cache[context.virtual_engine]
    attn.impl.forward(attn.layer_name, hidden_states, kv_cache, metadata,
                      need_gather_q_kv, output)
    return
def mla_forward_fake(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``mla_forward``: does nothing, returns None."""
    return None
# Register mla_forward as a custom op (invoked in this file as
# torch.ops.vllm.mla_forward). `output` is declared as the argument the op
# mutates, mla_forward_fake is supplied as the fake implementation, and the
# op is registered under the "PrivateUse1" dispatch key.
direct_register_custom_op(
    op_name="mla_forward",
    op_func=mla_forward,
    mutates_args=["output"],
    fake_impl=mla_forward_fake,
    dispatch_key="PrivateUse1",
)

View File

@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
@dataclass
class AscendSFAModules:
    """Bundle of the sub-modules handed to ``AscendSparseFlashAttention``.

    The three query-side entries may be ``None``; the remaining entries are
    required.
    """
    # Query-side modules (optional).
    q_a_proj: Optional[nn.Module]
    q_a_layernorm: Optional[nn.Module]
    q_proj: Optional[nn.Module]
    # Key/value-side modules.
    kv_a_proj_with_mqa: nn.Module
    kv_a_layernorm: nn.Module
    kv_b_proj: nn.Module
    # Output projection, rotary embedding and indexer module.
    o_proj: nn.Module
    rotary_emb: nn.Module
    indexer: nn.Module
class AscendSparseFlashAttention(MultiHeadLatentAttention):
    """SFA layer that runs attention through the ``sfa_forward`` custom op.

    Construction builds an ``Attention`` layer with ``use_mla=True`` and
    ``use_sfa=True`` and records this module in the current compilation
    config's ``static_forward_context`` under ``prefix``.
    """

    def __init__(
        self,
        hidden_size: int,
        enable_shared_expert_dp: bool,
        debug_layer_idx: int,
        first_k_dense_replace: int,
        tp_size: int,
        sfa_modules: AscendSFAModules,
        num_local_heads: int,
        scaling: float,
        layers: int,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        q_lora_rank: Optional[int],
        qk_nope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Store the configuration, build the attention layer and register
        this module under ``prefix``.

        Raises:
            ValueError: If ``prefix`` is already present in the compilation
                config's ``static_forward_context``.
        """
        # Only nn.Module's initialiser runs; the direct parent's __init__
        # is not invoked.
        nn.Module.__init__(self)

        # Layer placement / parallelism settings consulted in forward().
        self.hidden_size = hidden_size
        self.enable_shared_expert_dp = enable_shared_expert_dp
        self.debug_layer_idx = debug_layer_idx
        self.first_k_dense_replace = first_k_dense_replace
        self.tp_size = tp_size
        self.num_local_heads = num_local_heads
        self.layers = layers

        # Attention dimensions.
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.q_lora_rank = q_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.prefix = prefix

        self.sfa_attn = Attention(
            num_heads=num_local_heads,
            head_size=kv_lora_rank + qk_rope_head_dim,
            scale=scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            use_sfa=True,
            # SFA Args
            q_lora_rank=q_lora_rank,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            qk_head_dim=qk_head_dim,
            v_head_dim=v_head_dim,
            rotary_emb=sfa_modules.rotary_emb,
            q_a_proj=sfa_modules.q_a_proj,
            q_a_layernorm=sfa_modules.q_a_layernorm,
            q_proj=sfa_modules.q_proj,
            kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
            kv_a_layernorm=sfa_modules.kv_a_layernorm,
            kv_b_proj=sfa_modules.kv_b_proj,
            o_proj=sfa_modules.o_proj,
            indexer=sfa_modules.indexer)

        # Record this module by name; sfa_forward() later looks the layer up
        # by the same name through the forward context.
        static_context = (
            get_current_vllm_config().compilation_config.static_forward_context)
        if prefix in static_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_context[prefix] = self

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        """Allocate the output buffer and run ``torch.ops.vllm.sfa_forward``.

        ``positions``, ``kv_cache`` and ``attn_metadata`` are accepted but not
        read here; only ``hidden_states`` is passed on to the custom op.

        Returns:
            The op's output, viewed as 2-D with the last dim preserved.
        """
        shared_dp = self.enable_shared_expert_dp
        layer_idx = self.debug_layer_idx
        first_dense = self.first_k_dense_replace

        token_count = hidden_states.shape[0]
        need_gather_q_kv = bool(
            shared_dp and first_dense < layer_idx < self.layers)
        if need_gather_q_kv:
            # Simulate all gather to calculate output shape
            token_count = token_count * self.tp_size

        if shared_dp and layer_idx >= first_dense:
            # ceil(token_count / tp_size) rows.
            rows = token_count // self.tp_size
            if token_count % self.tp_size:
                rows += 1
            output_shape = (rows, hidden_states.shape[1])
        else:
            output_shape = hidden_states.shape

        # FIXME: This does not seem right, should make sure the buffer is fixed
        output = torch.empty(output_shape,
                             dtype=hidden_states.dtype,
                             device=hidden_states.device)
        torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
                                   self.prefix)
        return output.view(-1, output_shape[-1])
def sfa_forward(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the SFA attention impl of the layer registered as ``layer_name``.

    The layer is looked up in the current forward context's
    ``no_compile_layers``. The attention impl is handed ``output`` as its
    last argument; nothing is returned.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    attn = layer.sfa_attn

    # When metadata is present it is indexed by the attention layer's name
    # (presumably a per-layer mapping -- confirm); a falsy value is passed
    # through unchanged.
    metadata = context.attn_metadata
    if metadata:
        metadata = metadata[attn.layer_name]

    kv_cache = attn.kv_cache[context.virtual_engine]
    attn.impl.forward(hidden_states, kv_cache, metadata, need_gather_q_kv,
                      output)
    return
class Indexer(nn.Module):
    """Container for the indexer's projection layers and key norm.

    Holds three bias-free ``ReplicatedLinear`` projections (``wq_b``, ``wk``,
    ``weights_proj``), a ``LayerNorm`` over ``head_dim`` and the softmax
    scale ``head_dim ** -0.5``. ``forward`` performs no computation.
    """

    def __init__(self,
                 config,
                 dim: int = 7168,
                 n_heads: int = 64,
                 head_dim: int = 128,
                 index_topk: int = 2048,
                 q_lora_rank: int = 1536,
                 rope_head_dim: int = 64,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: Optional[str] = ""):
        """Build the projection layers; ``config`` is accepted but not read."""
        super().__init__()

        self.dim: int = dim  # 7168
        self.n_heads: int = n_heads  # 64
        self.head_dim: int = head_dim  # 128
        self.rope_head_dim: int = rope_head_dim  # 64
        self.index_topk: int = index_topk  # 2048
        self.q_lora_rank: int = q_lora_rank  # 1536

        def _projection(in_features: int, out_features: int,
                        layer_prefix: str):
            # All three projections share the same construction options.
            return ReplicatedLinear(
                in_features,
                out_features,
                bias=False,
                quant_config=quant_config,
                prefix=layer_prefix,
                return_bias=False,
            )

        self.wq_b = _projection(q_lora_rank, n_heads * head_dim,
                                f"{prefix}.wq_b")
        self.wk = _projection(dim, head_dim, f"{prefix}.wk")
        self.weights_proj = _projection(dim, n_heads,
                                        f"{prefix}.weights_proj")
        self.k_norm = nn.LayerNorm(head_dim)
        self.softmax_scale = head_dim**-0.5

    def forward(self):
        """No-op; returns None."""
        return None
def sfa_forward_fake(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``sfa_forward``: does nothing, returns None."""
    return None
# Register sfa_forward as a custom op (invoked in this file as
# torch.ops.vllm.sfa_forward). `output` is declared as the argument the op
# mutates, sfa_forward_fake is supplied as the fake implementation, and the
# op is registered under the "PrivateUse1" dispatch key.
direct_register_custom_op(
    op_name="sfa_forward",
    op_func=sfa_forward,
    mutates_args=["output"],
    fake_impl=sfa_forward_fake,
    dispatch_key="PrivateUse1",
)

File diff suppressed because it is too large Load Diff

View File

@@ -42,6 +42,8 @@ from vllm.model_executor.models.qwen2_5_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.utils import vllm_version_is
MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight
@@ -291,6 +293,40 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
self.hidden_size, -1)
return out_weight
def pad_qkv_weight_scale_offset(self, data):
    """Zero-pad each half of every q/k/v head in a column-vector parameter.

    ``data`` is reshaped to ``(-1, 3, origin_head_size, 1)``. Each head is
    split at ``half_origin_hidden_size_per_attention_head``, and each half
    gets ``half_pad_hidden_size_per_attention_head`` zeros appended along the
    head dimension. The halves are concatenated back together and the result
    is returned with shape ``(-1, 1)``.
    """
    split_at = self.half_origin_hidden_size_per_attention_head
    extra = self.half_pad_hidden_size_per_attention_head
    per_head = data.reshape(-1, 3, self.origin_hidden_size_per_attention_head,
                            1)
    halves = (per_head[:, :, :split_at, :], per_head[:, :, split_at:, :])
    # The pad spec runs from the last dim backwards: (0, 0) leaves the
    # trailing singleton dim alone, (0, extra) appends to the head dim.
    padded_halves = [
        torch.nn.functional.pad(half, (0, 0, 0, extra)) for half in halves
    ]
    return torch.cat(padded_halves, dim=2).reshape(-1, 1)
def pad_qkv_deq_scale_quant_bias(self, data):
    """Zero-pad each half of every q/k/v head in a flat parameter.

    ``data`` is reshaped to ``(-1, 3, origin_head_size)``. Each head is split
    at ``half_origin_hidden_size_per_attention_head``, and each half gets
    ``half_pad_hidden_size_per_attention_head`` zeros appended along the last
    dimension. The halves are concatenated back together and the result is
    returned flattened to 1-D.
    """
    split_at = self.half_origin_hidden_size_per_attention_head
    extra = self.half_pad_hidden_size_per_attention_head
    per_head = data.reshape(-1, 3, self.origin_hidden_size_per_attention_head)
    halves = (per_head[:, :, :split_at], per_head[:, :, split_at:])
    padded_halves = [
        torch.nn.functional.pad(half, (0, extra)) for half in halves
    ]
    return torch.cat(padded_halves, dim=2).reshape(-1)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
@@ -318,11 +354,23 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if ("attn.proj.weight" in name) and self.enable_pad:
if ("attn.proj.weight_scale" in name or
"attn.proj.weight_offset" in name) and self.enable_pad:
continue
elif ("attn.proj.deq_scale" in name
or "attn.proj.quant_bias" in name) and self.enable_pad:
continue
elif ("attn.qkv.weight_scale" in name
or "attn.qkv.weight_offset" in name) and self.enable_pad:
param.data = self.pad_qkv_weight_scale_offset(param.data)
elif ("attn.qkv.deq_scale" in name
or "attn.qkv.quant_bias" in name) and self.enable_pad:
param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
elif ("attn.proj.weight" in name) and self.enable_pad:
param.data = self.pad_proj_weight(param.data)
if ("attn.qkv.weight" in name) and self.enable_pad:
elif ("attn.qkv.weight" in name) and self.enable_pad:
param.data = self.pad_qkv_weight(param.data)
if ("attn.qkv.bias" in name) and self.enable_pad:
elif ("attn.qkv.bias" in name) and self.enable_pad:
param.data = self.pad_qkv_bias(param.data)
loaded_params.add(name)
return loaded_params
@@ -450,12 +498,20 @@ class AscendQwen2_5_VLForConditionalGeneration(
super().__init__(vllm_config=vllm_config, prefix=prefix)
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.visual = AscendQwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
if vllm_version_is("0.10.2"):
self.visual = AscendQwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = AscendQwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:

View File

@@ -1,6 +1,5 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
@@ -27,10 +26,19 @@ import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
try:
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
Qwen3VLConfig
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \
Qwen3VLMoeConfig
except ImportError:
pass
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
get_act_and_mul_fn)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2_5_vl import (
@@ -38,10 +46,29 @@ from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
try:
from vllm.model_executor.models.qwen3_vl import (
Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer,
Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration,
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
from vllm.model_executor.models.qwen3_vl_moe import (
Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo)
except ImportError:
Qwen3_VisionBlock = object
Qwen3_VisionPatchEmbed = object
Qwen3_VisionTransformer = object
Qwen3VLDummyInputsBuilder = object
Qwen3VLForConditionalGeneration = object
Qwen3VLMultiModalProcessor = object
Qwen3VLProcessingInfo = object
Qwen3VLMoeForConditionalGeneration = object
Qwen3VLMoeProcessingInfo = object
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
from vllm_ascend.utils import vllm_version_is
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
@@ -112,16 +139,14 @@ class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
quant_config, prefix)
self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
@@ -321,6 +346,133 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
return x
class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed):
    """Patch embedding whose forward pass is a plain matmul plus bias."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Multiply ``x`` by the flattened ``proj`` weight and add its bias."""
        # View the projection weight as (hidden_size, -1) and transpose it so
        # it can right-multiply ``x``.
        flat_weight = self.proj.weight.data.view(self.hidden_size, -1)
        projected = x.matmul(flat_weight.transpose(0, 1))
        return projected + self.proj.bias
class AscendQwen3_VisionBlock(Qwen3_VisionBlock):
    """Qwen3 vision block using the Ascend no-padding attention module."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Build the parent block, then replace ``self.attn``."""
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix, use_data_parallel)
        # Override the attention module created by the parent constructor.
        self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP, each with a residual connection."""
        attn_out = self.attn(self.norm1(x),
                             cu_seqlens=cu_seqlens,
                             cos=cos,
                             sin=sin)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer):
    """Qwen3 vision transformer built from the Ascend patch-embed and blocks."""

    def __init__(
        self,
        vision_config,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Build the parent model, then swap in the Ascend submodules."""
        super().__init__(vision_config, norm_eps, quant_config, prefix,
                         use_data_parallel)
        layer_norm_factory = partial(nn.LayerNorm, eps=norm_eps)

        self.patch_embed = AscendQwen3_VisionPatchEmbed(
            patch_size=self.patch_size,
            temporal_patch_size=self.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )

        vision_blocks = []
        for layer_idx in range(vision_config.depth):
            vision_blocks.append(
                AscendQwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                    norm_layer=layer_norm_factory,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                ))
        self.blocks = nn.ModuleList(vision_blocks)

        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

    def cal_cos_sin(self, rotary_pos_emb):
        """Return cos/sin of ``rotary_pos_emb``, duplicated and reshaped.

        Each of ``cos`` and ``sin`` is concatenated with itself along the
        last dimension and reshaped to
        ``(1, -1, 1, hidden_size_per_attention_head)``.
        """
        target_shape = (1, -1, 1, self.hidden_size_per_attention_head)
        cos_new, sin_new = [
            torch.cat((half, half), dim=-1).reshape(target_shape)
            for half in (rotary_pos_emb.cos(), rotary_pos_emb.sin())
        ]
        return cos_new, sin_new

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        """Run patch embedding, all blocks, and the mergers.

        Returns the merger output concatenated along dim 1 with the outputs
        of the deepstack mergers, in the order their layers were reached.
        """
        hidden_states = self.patch_embed(
            x.to(device=self.device, dtype=self.dtype))
        hidden_states = hidden_states + self.fast_pos_embed_interpolate(
            grid_thw)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        grid = torch.tensor(grid_thw, device=self.device, dtype=torch.int32)
        # One entry of h * w per temporal step of every grid row.
        tokens_per_frame = grid[:, 1] * grid[:, 2]
        cu_seqlens = torch.repeat_interleave(tokens_per_frame,
                                             grid[:, 0]).cpu().to(torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        hidden_states = hidden_states.unsqueeze(1)
        cos, sin = self.cal_cos_sin(rotary_pos_emb.to(hidden_states.device))

        deepstack_features = []
        for layer_num, block in enumerate(self.blocks):
            hidden_states = block(hidden_states,
                                  cu_seqlens=cu_seqlens,
                                  cos=cos,
                                  sin=sin)
            if layer_num in self.deepstack_visual_indexes:
                merger_idx = self.deepstack_visual_indexes.index(layer_num)
                merger = self.deepstack_merger_list[merger_idx]
                deepstack_features.append(merger(hidden_states))

        hidden_states = self.merger(hidden_states)
        # [seq_len, hidden_size * (1 + depth_of_deepstack)]
        return torch.cat([hidden_states, *deepstack_features], dim=1)
@MULTIMODAL_REGISTRY.register_processor(
Qwen2_5_VLMultiModalProcessor,
info=Qwen2_5_VLProcessingInfo,
@@ -332,12 +484,20 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
super().__init__(vllm_config=vllm_config, prefix=prefix)
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
if vllm_version_is("0.10.2"):
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
@@ -371,3 +531,101 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return video_embeds.split(sizes.tolist())
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
                                        info=Qwen3VLProcessingInfo,
                                        dummy_inputs=Qwen3VLDummyInputsBuilder)
class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration):
    """Qwen3-VL model whose ``visual`` module is the Ascend vision transformer."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    supports_encoder_tp_data = True

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.visual.": "visual.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the parent model, then replace ``self.visual``."""
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen3VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        # Only on vLLM 0.10.2 is the quant config passed through
        # `_maybe_ignore_quant_config`; otherwise it is used as-is.
        if vllm_version_is("0.10.2"):
            visual_quant_config = self._maybe_ignore_quant_config(quant_config)
        else:
            visual_quant_config = quant_config
        self.visual = AscendQwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=visual_quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel)
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
                                        info=Qwen3VLMoeProcessingInfo,
                                        dummy_inputs=Qwen3VLDummyInputsBuilder)
class AscendQwen3VLMoeForConditionalGeneration(
        Qwen3VLMoeForConditionalGeneration):
    """Qwen3-VL-MoE model whose ``visual`` module is the Ascend transformer."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    supports_encoder_tp_data = True

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.visual.": "visual.",
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the parent model, record multimodal settings, set ``visual``."""
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        multimodal_config = vllm_config.model_config.multimodal_config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        # Only on vLLM 0.10.2 is the quant config passed through
        # `_maybe_ignore_quant_config`; otherwise it is used as-is.
        if vllm_version_is("0.10.2"):
            visual_quant_config = self._maybe_ignore_quant_config(quant_config)
        else:
            visual_quant_config = quant_config
        self.visual = AscendQwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=visual_quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
        )

View File

@@ -40,6 +40,8 @@ from vllm.model_executor.models.qwen2_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.utils import vllm_version_is
MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight
@@ -343,10 +345,18 @@ class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.visual = AscendQwen2VisionTransformer(
self.config.vision_config,
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(
vllm_config.quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
if vllm_version_is("0.10.2"):
self.visual = AscendQwen2VisionTransformer(
self.config.vision_config,
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(
vllm_config.quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = AscendQwen2VisionTransformer(
self.config.vision_config,
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
quant_config=vllm_config.quant_config,
prefix=maybe_prefix(prefix, "visual"),
)

View File

@@ -1,156 +0,0 @@
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    """Qwen3 decoder layer that swaps layernorms for W8A8-quantized linears."""

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the parent layer, then replace layernorms when quantized.

        Without a ``quant_config`` the parent layer is left untouched.
        Otherwise each layernorm whose following linear layer uses
        ``AscendW8A8LinearMethod`` is replaced by ``AddRMSNormW8A8Quant``.
        """
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        if quant_config is None:
            return

        from vllm_ascend.quantization.quant_config import AscendQuantConfig
        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

        assert isinstance(quant_config, AscendQuantConfig), \
            "Expected quant_config to be an instance of AscendQuantConfig"

        # (layernorm attribute, linear layer that consumes its output)
        norm_targets = (
            ("input_layernorm", self.self_attn.qkv_proj),
            ("post_attention_layernorm", self.mlp.gate_up_proj),
        )
        for norm_attr, linear in norm_targets:
            if isinstance(linear.quant_method.quant_method,
                          AscendW8A8LinearMethod):
                setattr(
                    self, norm_attr,
                    AddRMSNormW8A8Quant(config.hidden_size,
                                        layer=linear,
                                        eps=config.rms_norm_eps))
# Maps a layer-type name to the decoder layer class; the only entry maps
# "attention" to CustomQwen3DecoderLayer.
ALL_DECODER_LAYER_TYPES = {
"attention": CustomQwen3DecoderLayer,
}
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):
    """Qwen2Model built with CustomQwen3DecoderLayer as its decoder layer."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Forward all arguments to the parent, fixing the decoder layer type."""
        init_kwargs = {
            "vllm_config": vllm_config,
            "prefix": prefix,
            "decoder_layer_type": CustomQwen3DecoderLayer,
        }
        super().__init__(**init_kwargs)
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper that uses ``CustomQwen3Model`` as ``self.model``."""

    # add `CustomQwen3Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the backbone, the LM head and the logits processor."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.lora_config = vllm_config.lora_config
        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        # Only the last pipeline rank gets a real LM head; with tied word
        # embeddings it reuses the input embedding module.
        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        elif config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the backbone model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone model and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights, skipping ``lm_head.`` when embeddings are tied."""
        if self.config.tie_word_embeddings:
            skip_prefixes = ["lm_head."]
        else:
            skip_prefixes = None
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)

View File

@@ -17,14 +17,14 @@
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Optional, Union
from typing import Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
@@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)
from vllm_ascend.utils import vllm_version_is
@@ -101,7 +98,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
self,
hidden_states,
attn_metadata=None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
@@ -120,7 +116,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
_metadata_for_padding=_metadata_for_padding,
)
return hidden_states
@@ -175,9 +170,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
if vllm_version_is("0.10.2"):
self.mlp = Qwen3MoeSparseMoeBlock(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@@ -189,60 +189,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.enable_sequence_parallelism = (
vllm_config.compilation_config.pass_config.
enable_sequence_parallelism if vllm_config is not None else False)
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    _metadata_for_padding: Optional[MetadataForPadding] = None,
) -> torch.Tensor:
    """Run one decoder layer and return ``(hidden_states, residual)``.

    When ``_metadata_for_padding`` is truthy and its
    ``not_dummy_and_is_prefill`` flag is set, the residual is sliced and
    the hidden states are gathered before attention and scattered after it
    through the metadata object's helpers.
    """
    # To prevent precision issues during the decoder phase when only prefilling enables SP
    if (self.enable_sequence_parallelism
            and _metadata_for_padding is not None):
        self.self_attn.o_proj.reduce_results = (
            not _metadata_for_padding.not_dummy_and_is_prefill)
    else:
        self.self_attn.o_proj.reduce_results = True

    sp_prefill = bool(_metadata_for_padding
                      and _metadata_for_padding.not_dummy_and_is_prefill)

    # Self Attention
    if residual is not None:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)
    else:
        residual = hidden_states
        if sp_prefill:
            residual = _metadata_for_padding.padding_slice(residual)
        hidden_states = self.input_layernorm(hidden_states)

    if sp_prefill:
        hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
            hidden_states)

    hidden_states = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
    )

    if sp_prefill:
        hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
            hidden_states)

    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(
        hidden_states, residual)

    if self.use_aclgraph:
        hidden_states = self.mlp(hidden_states)
    else:
        hidden_states = self.mlp(
            hidden_states, _metadata_for_padding=_metadata_for_padding)

    return hidden_states, residual
@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -254,11 +200,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
self.num_redundant_experts = parallel_config.num_redundant_experts
else:
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
@@ -281,60 +224,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    _metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this rank's layers ``start_layer`` to ``end_layer``.

    The first pipeline rank starts from ``inputs_embeds`` or the input
    embeddings; later ranks start from ``intermediate_tensors``. Ranks
    other than the last return ``IntermediateTensors``; the last rank
    returns the normalized hidden states.
    """
    if not get_pp_group().is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    else:
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds
        residual = None

    for layer_idx in range(self.start_layer, self.end_layer):
        decoder_layer = self.layers[layer_idx]
        hidden_states, residual = decoder_layer(
            positions,
            hidden_states,
            residual,
            _metadata_for_padding=_metadata_for_padding)

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)

    sp_prefill = (_metadata_for_padding
                  and _metadata_for_padding.not_dummy_and_is_prefill)
    if sp_prefill:
        hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
            hidden_states)

    return hidden_states
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
@@ -357,7 +248,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
# Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = []
@@ -378,16 +268,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Build the padding metadata and run the backbone model with it."""
    sp_metadata = init_metadata_for_sp(input_ids,
                                       self.enable_sequence_parallelism)
    return self.model(input_ids, positions, intermediate_tensors,
                      inputs_embeds, sp_metadata)

View File

@@ -0,0 +1,676 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
"""Inference-only Qwen3Next model."""
from collections.abc import Iterable
from typing import Optional
import torch
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from vllm import envs
from vllm.attention import AttentionBackend, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import RMSNormGated
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops.fused_recurrent import \
fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.fused_moe import FusedMoE
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.layernorm import \
GemmaRMSNorm as Qwen3NextRMSNorm
# yapf: enable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_mixer2 import \
mamba_v2_sharded_weight_loader
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.model_executor.models.qwen3_next import ( # isort: skip
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
fused_gdn_gating)
class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
@property
def mamba_type(self) -> str:
    """Return the mamba type identifier of this layer."""
    layer_kind = "linear_attention"
    return layer_kind
def get_attn_backend(self) -> type["AttentionBackend"]:
    """Return the ``GDNAttentionBackend`` class."""
    # The backend module is imported at call time, not at module import.
    from vllm.v1.attention.backends import gdn_attn
    return gdn_attn.GDNAttentionBackend
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
    """Return the state dtypes for the model dtype and mamba cache dtype."""
    model_dtype = self.model_config.dtype
    cache_dtype = self.cache_config.mamba_cache_dtype
    return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
        model_dtype, cache_dtype)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return the state shapes computed from this layer's dimensions."""
    shape_args = (
        self.tp_size,
        self.num_k_heads,
        self.num_v_heads,
        self.head_k_dim,
        self.head_v_dim,
        self.conv_kernel_size,
        self.num_spec,
    )
    return MambaStateShapeCalculator.gated_delta_net_state_shape(*shape_args)
def __init__(
    self,
    config: Qwen3NextConfig,
    model_config: Optional[ModelConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    speculative_config: Optional[SpeculativeConfig] = None,
    prefix: str = "",
) -> None:
    """Build the gated-delta-net layer and register it by ``prefix``.

    Raises:
        ValueError: if ``prefix`` is already present in the current
            compilation config's ``static_forward_context``.
    """
    nn.Module.__init__(self)

    # Parallel layout.
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()

    # Dimensions read from the model config.
    self.hidden_size = config.hidden_size
    self.num_v_heads = config.linear_num_value_heads
    self.num_k_heads = config.linear_num_key_heads
    self.head_k_dim = config.linear_key_head_dim
    self.head_v_dim = config.linear_value_head_dim
    self.key_dim = self.head_k_dim * self.num_k_heads
    self.value_dim = self.head_v_dim * self.num_v_heads

    self.conv_kernel_size = config.linear_conv_kernel_dim
    self.layer_idx = extract_layer_index(prefix)
    self.activation = config.hidden_act
    self.act = ACT2FN[config.hidden_act]
    self.layer_norm_epsilon = config.rms_norm_eps
    self.prefix = prefix

    # Keep the configuration objects for later use.
    self.config = config
    self.model_config = model_config
    self.cache_config = cache_config
    self.quant_config = quant_config
    self.speculative_config = speculative_config
    if self.speculative_config:
        self.num_spec = self.speculative_config.num_speculative_tokens
    else:
        self.num_spec = 0

    # QKV
    self.conv_dim = self.key_dim * 2 + self.value_dim
    self.conv1d = ColumnParallelLinear(
        input_size=self.conv_kernel_size,
        output_size=self.conv_dim,
        bias=False,
        prefix=f"{prefix}.conv1d",
    )
    self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

    # projection of the input hidden states
    self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
    self.projection_size_ba = self.num_v_heads * 2
    self.in_proj = MergedColumnParallelLinear(
        input_size=self.hidden_size,
        output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.in_proj",
    )

    # Replace the conv1d weight loader with a sharded loader built from
    # the query, key and value settings below.
    qk_shard_settings = (self.key_dim, 0, False)
    v_shard_settings = (self.value_dim, 0, False)
    conv_weight_loader = mamba_v2_sharded_weight_loader(
        [qk_shard_settings, qk_shard_settings, v_shard_settings],
        self.tp_size,
        self.tp_rank,
    )
    delattr(self.conv1d.weight, "weight_loader")
    set_weight_attrs(self.conv1d.weight,
                     {"weight_loader": conv_weight_loader})

    # selective projection used to make dt, B and C input dependent

    # time step projection (discretization)
    # instantiate once and copy inv_dt in init_weights of PretrainedModel
    self.dt_bias = nn.Parameter(
        torch.ones(self.num_v_heads // self.tp_size))
    self.A_log = nn.Parameter(
        torch.empty(
            divide(self.num_v_heads, self.tp_size),
            dtype=torch.float32,
        ))

    set_weight_attrs(self.A_log,
                     {"weight_loader": sharded_weight_loader(0)})
    set_weight_attrs(self.dt_bias,
                     {"weight_loader": sharded_weight_loader(0)})

    self.norm = RMSNormGated(
        self.head_v_dim,
        eps=self.layer_norm_epsilon,
        norm_before_gate=True,
        device="npu",
    )

    self.out_proj = RowParallelLinear(self.value_dim,
                                      self.hidden_size,
                                      bias=False,
                                      input_is_parallel=True,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.out_proj")

    # Register this layer in the static forward context under its prefix;
    # a prefix may be registered only once.
    static_context = (
        get_current_vllm_config().compilation_config.static_forward_context)
    if prefix in static_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    static_context[prefix] = self
def _forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
):
"""Run the gated-delta-net mixer for the current step and fill ``output``.

The per-layer ``GDNAttentionMetadata`` is looked up in the forward context
under ``self.prefix``. The method projects ``hidden_states`` with
``in_proj``, applies the causal convolution to the non-speculative tokens,
runs the gated delta rule against the cached SSM state, applies the gated
norm with ``z`` and finally writes ``out_proj``'s result into
``output[:num_actual_tokens]``.

Returns ``None``. When no attention metadata is available the method
returns immediately and ``output`` is left untouched.
"""
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None:
# V1 profile run
return
# Metadata is keyed by layer prefix; pick this layer's entry.
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
spec_sequence_masks = attn_metadata.spec_sequence_masks
spec_token_masks = attn_metadata.spec_token_masks
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
# kv_cache entry 0 holds the conv state (used with its last two dims
# swapped), entry 1 holds the recurrent SSM state.
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
# Only the first num_actual_tokens rows of hidden_states are processed.
num_actual_tokens = (attn_metadata.num_prefill_tokens +
attn_metadata.num_decode_tokens +
attn_metadata.num_spec_decode_tokens)
num_accepted_tokens = attn_metadata.num_accepted_tokens
# 1. Set up dimensions for reshapes later
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
if spec_token_masks is not None:
spec_token_masks = spec_token_masks[:num_actual_tokens]
# Split the merged projection into its per-rank qkvz and ba parts.
projected_states_qkvz, projected_states_ba = torch.split(
projected_states,
[
self.projection_size_qkvz // self.tp_size,
self.projection_size_ba // self.tp_size
],
dim=-1,
)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba)
# Flatten the two trailing dims of q/k/v and concatenate them so the
# convolution sees a single feature axis.
query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
(query, key, value))
mixed_qkv = torch.cat((query, key, value), dim=-1)
# 2. Convolution sequence transformation
# Drop the singleton middle dim of the conv weight (dims 0 and 2 are kept).
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
# Partition tokens into speculative and non-speculative groups.
if spec_sequence_masks is not None:
if (attn_metadata.num_prefills == 0
and attn_metadata.num_decodes == 0):
mixed_qkv_spec = mixed_qkv
mixed_qkv_non_spec = None
else:
mixed_qkv_spec = mixed_qkv[spec_token_masks]
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
else:
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
# NOTE(review): there is no "2.1" step in this block - mixed_qkv_spec is
# handed to rearrange_mixed_qkv below without passing through a
# convolution here. Confirm that this is intended.
# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
).transpose(0, 1)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[:attn_metadata
.num_decodes],
# validate_data=True,
)
else:
mixed_qkv_non_spec = None
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
mixed_qkv_spec)
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
mixed_qkv_non_spec)
# Gating terms: beta from b, g from (A_log, a, dt_bias); both get a leading
# singleton dim and are then split with the same masks as mixed_qkv.
beta = b.sigmoid()
g = fused_gdn_gating(self.A_log, a, self.dt_bias)
g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta))
if spec_sequence_masks is not None:
if (attn_metadata.num_prefills == 0
and attn_metadata.num_decodes == 0):
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g[:, spec_token_masks]
beta_spec = beta[:, spec_token_masks]
g_non_spec = g[:, ~spec_token_masks]
beta_non_spec = beta[:, ~spec_token_masks]
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
# 3. Recurrent attention
# 3.1: process the multi-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[:attn_metadata.
num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
))
else:
core_attn_out_spec, last_recurrent_state = None, None
# 3.2: process the remaining part
if attn_metadata.num_prefills > 0:
# Gather each sequence's cached state; sequences flagged as having no
# initial state start from zeros.
initial_state = ssm_state[
non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
batch_size = initial_state.shape[0]
core_attn_out = []
last_recurrent_state = []
# The chunked kernel is invoked once per sequence, on that sequence's
# token slice taken from non_spec_query_start_loc.
for b_idx in range(batch_size):
start, end = non_spec_query_start_loc[
b_idx], non_spec_query_start_loc[b_idx + 1]
cur_q = query_non_spec[:, start:end, ...]
cur_k = key_non_spec[:, start:end, ...]
cur_v = value_non_spec[:, start:end, ...]
cur_g = g_non_spec[:, start:end, ...]
cur_b = beta_non_spec[:, start:end, ...]
cur_state = initial_state[b_idx].unsqueeze(0)
(
cur_core_attn_out_non_spec,
cur_last_recurrent_state,
) = chunk_gated_delta_rule(
query=cur_q,
key=cur_k,
value=cur_v,
g=cur_g,
beta=cur_b,
initial_state=cur_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
)
core_attn_out.append(cur_core_attn_out_non_spec)
last_recurrent_state.append(cur_last_recurrent_state)
# Reassemble the per-sequence outputs into one buffer whose token dim
# (dim 1) is the total non-speculative token count.
tar_dtype = core_attn_out[0].dtype
tar_device = core_attn_out[0].device
tar_shape = list(core_attn_out[0].shape)
tar_shape[1] = non_spec_query_start_loc[-1]
core_attn_out_non_spec = torch.empty(tar_shape,
dtype=tar_dtype,
device=tar_device)
for b_idx in range(batch_size):
cur_core_attn_out = core_attn_out[b_idx]
start, end = non_spec_query_start_loc[
b_idx], non_spec_query_start_loc[b_idx + 1]
core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
last_recurrent_state = torch.cat(last_recurrent_state, dim=0)
# Init cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
ssm_state.dtype)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[:attn_metadata.
num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
))
else:
core_attn_out_non_spec, last_recurrent_state = None, None
# Merge core attention output
# When both groups produced output, scatter them back into token order
# using the speculative token mask.
if (spec_sequence_masks is not None
and core_attn_out_non_spec is not None):
core_attn_out = torch.empty(
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device,
)
core_attn_out[:, spec_token_masks] = core_attn_out_spec
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
elif spec_sequence_masks is not None:
core_attn_out = core_attn_out_spec
else:
core_attn_out = core_attn_out_non_spec
# Remember z's shape so the normalized output can be restored to it.
z_shape_og = z.shape
# reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
# Merge the last two dims before the output projection.
core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)')
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
    """Qwen3-Next decoder layer that uses ``CustomQwen3NextGatedDeltaNet``.

    The layer holds either a linear-attention mixer or a full-attention
    block (selected by ``layer_type``), followed by a dense or sparse-MoE
    MLP, two RMS norms and optional learnable layer-scale parameters.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_type: str,
        prefix: str = "",
    ) -> None:
        """Build the sub-modules of one decoder layer.

        Args:
            vllm_config: Engine configuration the sub-modules are built from.
            layer_type: ``"linear_attention"`` or ``"full_attention"``.
            prefix: Parameter-name prefix; the layer index is parsed from it.

        Raises:
            ValueError: If ``layer_type`` is neither supported value.
        """
        # Only nn.Module is initialised here; the parent class constructor
        # is deliberately not called.
        nn.Module.__init__(self)

        hf_config = vllm_config.model_config.hf_config
        model_cfg = vllm_config.model_config
        cache_cfg = vllm_config.cache_config
        quant_cfg = vllm_config.quant_config
        spec_cfg = vllm_config.speculative_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

        # Token mixer.
        if self.layer_type == "linear_attention":
            self.linear_attn = CustomQwen3NextGatedDeltaNet(
                hf_config,
                model_config=model_cfg,
                cache_config=cache_cfg,
                quant_config=quant_cfg,
                speculative_config=spec_cfg,
                prefix=f'{prefix}.linear_attn')
        elif self.layer_type == "full_attention":
            self.self_attn = Qwen3NextAttention(
                hf_config,
                model_config=model_cfg,
                cache_config=cache_cfg,
                quant_config=quant_cfg,
                prefix=f'{prefix}.self_attn',
            )
        else:
            raise ValueError(f"Invalid layer_type {self.layer_type}")

        # Feed-forward: sparse MoE on every ``decoder_sparse_step``-th layer
        # unless the layer is listed in ``mlp_only_layers``.
        dense_only_layers = getattr(hf_config, "mlp_only_layers", [])
        use_moe = (self.layer_idx not in dense_only_layers) and (
            hf_config.num_experts > 0 and
            (self.layer_idx + 1) % hf_config.decoder_sparse_step == 0)
        if use_moe:
            self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config,
                                               prefix=f"{prefix}.mlp")
        else:
            self.mlp = Qwen3NextMLP(
                hidden_size=hf_config.hidden_size,
                intermediate_size=hf_config.intermediate_size,
                hidden_act=hf_config.hidden_act,
                quant_config=quant_cfg,
            )

        self.input_layernorm = Qwen3NextRMSNorm(hf_config.hidden_size,
                                                eps=hf_config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3NextRMSNorm(
            hf_config.hidden_size, eps=hf_config.rms_norm_eps)

        # Optional zero-initialised scales for the attention and FFN branches.
        self.layer_scale = getattr(hf_config, "layer_scale", False)
        if self.layer_scale:
            scale_shape = (1, 1, hf_config.hidden_size)
            self.attn_layer_scale = torch.nn.Parameter(
                torch.zeros(*scale_shape, dtype=hf_config.torch_dtype))
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(*scale_shape, dtype=hf_config.torch_dtype))
@support_torch_compile
class CustomQwen3NextModel(Qwen3NextModel):
    """Qwen3-Next backbone built from ``CustomQwen3NextDecoderLayer`` layers."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the embedding, the decoder stack and the final norm.

        Args:
            vllm_config: Engine configuration.
            prefix: Parameter-name prefix of this module.
        """
        # Only nn.Module is initialised; the parent constructor is not called.
        nn.Module.__init__(self)

        hf_config: Qwen3NextConfig = vllm_config.model_config.hf_config
        lora_cfg = vllm_config.lora_config
        eplb_cfg = vllm_config.parallel_config.eplb_config

        self.num_redundant_experts = eplb_cfg.num_redundant_experts
        self.config = hf_config

        # LoRA may extend the vocabulary by a per-adapter number of tokens.
        if lora_cfg:
            extra_vocab = (lora_cfg.lora_extra_vocab_size *
                           (lora_cfg.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
        )

        def build_layer(prefix: str):
            # The layer type is looked up by the index encoded in the prefix.
            layer_index = extract_layer_index(prefix)
            return CustomQwen3NextDecoderLayer(
                vllm_config,
                layer_type=hf_config.layer_types[layer_index],
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

        self.norm = Qwen3NextRMSNorm(hf_config.hidden_size,
                                     eps=hf_config.rms_norm_eps)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint name is first matched against the stacked-parameter
        table, then against the expert mapping, and otherwise loaded directly
        under its own name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (possibly rewritten) parameter names that were handled.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("in_proj", "in_proj_qkvz", 0),
            ("in_proj", "in_proj_ba", 1),
        ]

        params = dict(self.named_parameters())
        handled: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        for name, loaded_weight in weights:
            # Rotary buffers and MTP weights are not loaded by this module.
            if "rotary_emb.inv_freq" in name or name.startswith("mtp."):
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # name = apply_attn_prefix(name, params_dict)
                if name not in params:
                    continue
                target = params[name]
                target.weight_loader(target, loaded_weight, shard_id)
                break
            else:
                for (param_name, weight_name, expert_id,
                     shard_id) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if (name.endswith((".bias", "_bias"))
                            and name not in params):
                        continue
                    target = params[name]
                    target.weight_loader(target,
                                         loaded_weight,
                                         name,
                                         shard_id=shard_id,
                                         expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    target = params[name]
                    loader = getattr(target, "weight_loader",
                                     default_weight_loader)
                    loader(target, loaded_weight)
            handled.add(name)
        return handled
class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM):
    """Causal-LM wrapper around ``CustomQwen3NextModel``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor, then record the
        MoE hyper-parameters taken from the model's sparse-MoE layers.

        Args:
            vllm_config: Engine configuration.
            prefix: Parameter-name prefix of this module.

        Raises:
            AssertionError: If prefix caching is enabled or V1 is not in use.
            RuntimeError: If no layer with a sparse-MoE MLP is found.
        """
        # Only nn.Module is initialised; the parent constructor is not called.
        nn.Module.__init__(self)

        hf_config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        lora_cfg = vllm_config.lora_config

        assert not vllm_config.cache_config.enable_prefix_caching, \
            "Qwen3Next currently does not support prefix caching"
        assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"

        self.quant_config = vllm_config.quant_config
        self.config = hf_config
        self.scheduler_config = vllm_config.scheduler_config
        self.model = CustomQwen3NextModel(vllm_config=vllm_config,
                                          prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = hf_config.vocab_size
        if lora_cfg:
            self.unpadded_vocab_size += lora_cfg.lora_extra_vocab_size

        # We need bigger padding if using lora for kernel compatibility.
        if lora_cfg:
            vocab_padding = lora_cfg.lora_vocab_padding_size
        else:
            vocab_padding = DEFAULT_VOCAB_PADDING_SIZE
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            padding_size=vocab_padding,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

        # Set MoE hyperparameters
        self.expert_weights = []
        self.moe_layers: list[FusedMoE] = []
        moe_block = None
        for decoder_layer in self.model.layers:
            # Layers owned by other pipeline stages are placeholders.
            if isinstance(decoder_layer, PPMissingLayer):
                continue
            assert isinstance(decoder_layer, Qwen3NextDecoderLayer)
            if isinstance(decoder_layer.mlp, Qwen3NextSparseMoeBlock):
                moe_block = decoder_layer.mlp
                self.moe_layers.append(decoder_layer.mlp.experts)

        if moe_block is None:
            raise RuntimeError("No Qwen3Next layer found in the model.layers.")

        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        # Expert counts are copied from the last sparse-MoE block seen above.
        self.num_logical_experts = moe_block.n_logical_experts
        self.num_physical_experts = moe_block.n_physical_experts
        self.num_local_physical_experts = moe_block.n_local_physical_experts
        self.num_routed_experts = moe_block.n_routed_experts
        self.num_redundant_experts = moe_block.n_redundant_experts

View File

@@ -20,6 +20,7 @@ import torch
import vllm_ascend.ops.common_fused_moe # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.register_custom_ops # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.rotary_embedding import (
@@ -34,19 +35,20 @@ class dummyFusionOp:
def register_dummy_fusion_op() -> None:
torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
torch.ops._C_ascend.rms_norm = dummyFusionOp(name="rms_norm")
torch.ops._C_ascend.fused_add_rms_norm = dummyFusionOp(
name="fused_add_rms_norm")
torch.ops._C_ascend.static_scaled_fp8_quant = dummyFusionOp(
name="static_scaled_fp8_quant")
torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
torch.ops._C_ascend.dynamic_scaled_fp8_quant = dummyFusionOp(
name="dynamic_scaled_fp8_quant")
torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
torch.ops._C_ascend.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
name="dynamic_per_token_scaled_fp8_quant")
torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
torch.ops._C_ascend.rms_norm_static_fp8_quant = dummyFusionOp(
name="rms_norm_static_fp8_quant")
torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
torch.ops._C_ascend.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
name="fused_add_rms_norm_static_fp8_quant")
torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
torch.ops._C_ascend.rms_norm_dynamic_per_token_quant = dummyFusionOp(
name="rms_norm_dynamic_per_token_quant")

View File

@@ -35,8 +35,10 @@ class AscendSiluAndMul(SiluAndMul):
from vllm_ascend.utils import is_310p
torch.ops.vllm.maybe_prefetch_mlp_down_proj(x)
if is_310p():
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
else:
out = torch_npu.npu_swiglu(x)
torch.ops.vllm.maybe_wait_prefetch_done(out)
return out

View File

@@ -0,0 +1,539 @@
# Adapted from vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors
from typing import Optional, Union
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
PAD_SLOT_ID = -1
def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """Reference depthwise causal 1-D convolution written in plain PyTorch.

    Args:
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        initial_states: (batch, dim, width - 1); prepended to ``x`` as
            history instead of zero padding.
        return_final_states: whether to also return the trailing
            ``width - 1`` inputs as the new state.
        final_states_out: (batch, dim, width - 1); when given, the final
            state is copied into it in place.
        activation: ``None``, ``"silu"`` or ``"swish"``.

    Returns:
        ``(out, final_states)`` where ``out`` is (batch, dim, seqlen) in the
        dtype of ``x`` and ``final_states`` is ``None`` unless
        ``return_final_states`` is set.

    Raises:
        NotImplementedError: for any other ``activation`` value.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    input_dtype = x.dtype
    x = x.to(weight.dtype)
    num_tokens = x.shape[-1]
    channels, kernel_width = weight.shape
    depthwise_kernel = weight.unsqueeze(1)

    if initial_states is not None:
        # The cached history supplies the left context, so no padding.
        x = torch.cat([initial_states, x], dim=-1)
        conv_out = F.conv1d(x,
                            depthwise_kernel,
                            bias,
                            padding=0,
                            groups=channels)
    else:
        conv_out = F.conv1d(x,
                            depthwise_kernel,
                            bias,
                            padding=kernel_width - 1,
                            groups=channels)
    # Keep only the causal part of the convolution result.
    conv_out = conv_out[..., :num_tokens]

    if return_final_states:
        # A negative left pad crops; a positive one zero-fills, so the
        # result always has exactly ``width - 1`` trailing positions.
        tail = F.pad(x, (kernel_width - 1 - x.shape[-1], 0)).to(
            input_dtype)  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = tail
        else:
            final_states_out.copy_(tail)

    if activation is not None:
        conv_out = F.silu(conv_out)
    conv_out = conv_out.to(dtype=input_dtype)

    if return_final_states:
        return conv_out, final_states_out
    return conv_out, None
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """Variable-length causal conv1d over concatenated sequences.

    Each sequence is sliced out of ``x`` along the last dim, convolved with
    ``causal_conv1d_ref`` and its final state is written back into
    ``conv_states`` in place.

    Args:
        x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen;
            sequences are concatenated from left to right for varlen.
        weight: (dim, width)
        bias: (dim,)
        query_start_loc: (batch + 1) int32. Cumulative sequence lengths,
            prepended by 0, e.g. ``[0, 10, 16, 17]`` for ``x.shape=(dim, 17)``.
            Required.
        cache_indices: (batch) int32. State slot of each sequence, i.e.
            ``conv_state = conv_states[cache_indices[batch_id]]``. Required.
        has_initial_state: (batch) bool. Whether the cached state is used as
            the initial state of the sequence. ``None`` means no sequence
            has an initial state.
        conv_states: (..., dim, width - 1). Updated in place. Required.
        activation: either None or "silu" or "swish".
        pad_slot_id: sequences whose entry in ``cache_indices`` equals this
            value are skipped, e.g. ``[pad_slot_id, 1, 20, pad_slot_id]``
            skips entries 0 and 3.

    Returns:
        The per-sequence outputs concatenated along the last dim. Skipped
        (padded) sequences contribute no columns, so the result can be
        shorter than ``x`` along that dim.

    Raises:
        NotImplementedError: for an unsupported ``activation``.
        ValueError: if ``query_start_loc``, ``cache_indices`` or
            ``conv_states`` is missing.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if query_start_loc is None or cache_indices is None or conv_states is None:
        raise ValueError(
            "query_start_loc, cache_indices and conv_states are required")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    # Pull all per-sequence scalars to the host once instead of once per
    # loop iteration.
    seqlens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
    slot_ids = cache_indices.tolist()
    if has_initial_state is None:
        use_initial = [False] * len(seqlens)
    else:
        use_initial = has_initial_state.tolist()

    outputs = []
    for seq_idx, x_seq in enumerate(torch.split(x, seqlens, dim=-1)):
        slot = slot_ids[seq_idx]
        # Honour the caller-supplied pad id, not just the module default.
        if slot == pad_slot_id:
            continue
        # Integer indexing returns a view, so copying the final state into
        # it updates ``conv_states`` in place.
        state = conv_states[slot]
        out_seq, _ = causal_conv1d_ref(
            x_seq,
            weight,
            bias,
            activation=activation,
            return_final_states=True,
            final_states_out=state.unsqueeze(0),
            initial_states=state if use_initial[seq_idx] else None)
        outputs.append(out_seq)

    if not outputs:
        # Every sequence was padding: nothing to convolve.
        return x[..., :0]
    return torch.cat(outputs, dim=-1)
@triton.jit()
def _causal_conv1d_update_kernel(
# Pointers to matrices
x_ptr, # (batch, dim, seqlen)
w_ptr, # (dim, width)
bias_ptr,
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
intermediate_conv_window_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
dim: tl.constexpr,
seqlen: tl.constexpr,
state_len: tl.constexpr,
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr,
stride_x_dim: tl.constexpr,
stride_x_token: tl.constexpr,
stride_w_dim: tl.constexpr,
stride_w_width: tl.constexpr,
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_inter_seq: tl.constexpr,
stride_inter_step: tl.constexpr,
stride_inter_dim: tl.constexpr,
stride_inter_win: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
SAVE_INTERMEDIATE: tl.constexpr,
):
# Causal conv1d "update" step. Each program handles one sequence
# (program_id(0)) and one BLOCK_N-wide slice of features (program_id(1)):
#   1. read the previous window columns from conv_state,
#   2. shift conv_state and write the new tokens of x into it,
#   3. compute the convolution output for each of the `seqlen` tokens and
#      store it through o_ptr (optionally applying SiLU),
#   4. optionally store the window after each token into
#      intermediate_conv_window_ptr.
# cache_seqlens_ptr is accepted but never read in this kernel.
# NOTE(review): the compute loop in STEP 5 only has branches for
# KERNEL_WIDTH 2, 3 and 4; other widths are not handled there.
# ruff: noqa: E501
idx_seq = tl.program_id(0)
if idx_seq >= batch:
return
# [BLOCK_N,] elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices).to(
tl.int64)
else:
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
if IS_SPEC_DECODING:
# The rolling of conv state:
#
# Before forward, the conv_state is:
# [history1, history2, ..., historyM].
#
# After forward, the conv_state becomes:
# [history2, ..., historyM, draft1, draft2, ..., draftN].
#
# After acceptance, it becomes:
#
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = tl.load(num_accepted_tokens_ptr +
idx_seq) - 1
else:
conv_state_token_offset = 0
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
# col0..col2 hold the oldest-to-newest history columns of the window; how
# many are loaded depends on KERNEL_WIDTH.
if KERNEL_WIDTH >= 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 3:
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 4:
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 5:
# The pointer is computed but the matching load is disabled, so no col3
# exists for width 5.
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
#col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# The conv_state updates works in a sliding window manner,
# at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
# Positions covered by this mask keep (shifted) old state; the remaining
# positions are filled from x below.
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
# VAL is the number of leading state positions that are not taken from x.
VAL = state_len - seqlen
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
) # [BLOCK_N]
x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens - VAL >= 0)[:, None]
& (idx_tokens - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
# Barrier between the state loads above and the state store below.
tl.debug_barrier()
new_conv_state = tl.where(mask, conv_state, loaded_x)
conv_state_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
conv_state_ptrs_target = (conv_state_base +
(idx_tokens * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)
# STEP 3: init accumulator
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
# STEP 4:
# PRE-LOAD WEIGHTS
# first kernel column, configured for weights to handle BLOCK_N features in range
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
mask_w = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
x_base_1d = x_base # starting of chunk [BLOCK_N]
mask_x_1d = idx_feats < dim
# STEP 5: compute each token
for idx_token in tl.static_range(seqlen):
acc = acc_preload
matrix_w = w_col0
matrix_x = col0
# For each tap j the pair (matrix_w, matrix_x) is selected: history columns
# for the earlier taps and the current token of x for the last tap.
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
# Slide the window by one token: drop the oldest column and append the
# token just consumed (matrix_x holds it after the last tap).
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
col0 = col1
col1 = col2
col2 = matrix_x
if SILU_ACTIVATION:
# SiLU: acc * sigmoid(acc)
acc = acc / (1 + tl.exp(-acc))
# mask_1d = (idx_token < seqlen) & (
# idx_feats < dim
# ) # token-index # feature-index
maskL = idx_feats < dim
maskR = tl.full(maskL.shape, False, tl.int1)
mask_1d = tl.where(idx_token < seqlen, maskL, maskR)
o_ptrs = (o_ptr + (idx_seq) * stride_o_seq +
idx_token * stride_o_token + (idx_feats * stride_o_dim))
tl.store(o_ptrs, acc, mask=mask_1d)
if SAVE_INTERMEDIATE:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr = (intermediate_conv_window_ptr +
conv_state_batch_coord * stride_inter_seq +
idx_token * stride_inter_step +
idx_feats * stride_inter_dim)
if KERNEL_WIDTH >= 2:
tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
if KERNEL_WIDTH >= 3:
tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
if KERNEL_WIDTH >= 4:
tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
def causal_conv1d_update_npu(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Union[bool, str, None] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    intermediate_conv_window: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
    metadata=None,
    validate_data=False,
):
    """Launch the Triton causal-conv1d update kernel.

    ``x`` is overwritten in place with the convolution output and
    ``conv_state`` is rolled forward in place.

    Args:
        x: (batch, dim) or (batch, dim, seqlen)
            [shape=2: single token prediction]
            [shape=3: single or multiple tokens prediction]
        conv_state: (..., dim, state_len), where state_len >= width - 1
        weight: (dim, width); only widths 2, 3 and 4 are implemented by the
            kernel.
        bias: (dim,)
        activation: ``True``/"silu"/"swish" enable SiLU; ``False``/``None``
            disable it.
        cache_seqlens: (batch,), dtype int32. Circular-buffer mode; forwarded
            to the kernel but not used by it (asserted to be ``None`` when
            ``validate_data`` is set).
        conv_state_indices: (batch,), dtype int32.
            If not None, the conv_state is a larger tensor along the batch
            dim, and we are selecting the batch coords specified by
            conv_state_indices. Useful for a continuous batching scenario.
        num_accepted_tokens: if not None, enables the speculative-decoding
            path: the kernel reads one value per sequence and uses
            ``value - 1`` as the token offset into ``conv_state``.
        intermediate_conv_window: if not None, the kernel stores the window
            after every consumed token into it; it is indexed with four
            strides as [seq(cache line), step, dim, win(K-1)].
        pad_slot_id: if conv_state_indices is passed, lets the kernel
            identify padded entries that will not be processed, e.g.
            ``[pad_slot_id, 1, 20, pad_slot_id]`` skips entries 0 and 3.
        metadata: unused by this function.
        validate_data: run the shape/stride assertions before launching.

    Returns:
        (batch, dim) or (batch, dim, seqlen), matching the rank of ``x``;
        it shares storage with ``x``.

    Raises:
        NotImplementedError: if the kernel width is not 2, 3 or 4.
    """
    if validate_data:
        assert cache_seqlens is None  # not implemented yet - ok for vLLM
        assert pad_slot_id is not None
        assert x.stride(1) == 1
    if isinstance(activation, bool):
        activation = "silu" if activation is True else None
    elif activation is not None:
        assert activation in ["silu", "swish"]
    unsqueeze = x.dim() == 2
    if unsqueeze:
        # make it (batch, dim, seqlen) with seqlen == 1
        x = x.unsqueeze(-1)
    batch, dim, seqlen = x.shape
    _, width = weight.shape
    # The kernel's per-token compute loop only has branches for these
    # widths; fail loudly instead of producing wrong output or an obscure
    # compile error.
    if not 2 <= width <= 4:
        raise NotImplementedError(
            f"causal_conv1d_update only supports kernel width 2, 3 or 4 "
            f"(got {width})")
    # conv_state: (..., dim, state_len), where state_len >= width - 1
    num_cache_lines, _, state_len = conv_state.size()
    if validate_data:
        assert dim == weight.size(0)
        assert (
            conv_state.stride(-2) == 1
        ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
        assert state_len >= width - 1
        # when above happens, we don't shift-left to keep any records in conv_state
        assert dim == conv_state.size(1)
        if conv_state_indices is None:
            assert conv_state.size(0) >= batch
        else:
            assert (batch, ) == conv_state_indices.shape
        assert num_cache_lines >= batch
        assert weight.stride(1) == 1  # Need this
        assert cache_seqlens is None  # not needed for vLLM - circular buffer

    # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
    out = x
    stride_w_dim, stride_w_width = weight.stride()
    stride_x_seq, stride_x_dim, stride_x_token = x.stride(
    )  # X (batch, dim, seqlen)
    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
    stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
    )
    stride_state_indices = (conv_state_indices.stride(0)
                            if conv_state_indices is not None else 0)
    state_len = width - 1 + (seqlen - 1)  # effective state_len needed
    np2_statelen = triton.next_power_of_2(state_len)

    def grid(META):
        # One program per (sequence, BLOCK_N-wide feature slice).
        return (
            batch,
            triton.cdiv(dim, META["BLOCK_N"]),
        )

    # prepare intermediate buffer strides if provided
    if intermediate_conv_window is not None:
        stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
            intermediate_conv_window.stride(0),
            intermediate_conv_window.stride(1),
            intermediate_conv_window.stride(2),
            intermediate_conv_window.stride(3),
        )
    else:
        stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0

    _causal_conv1d_update_kernel[grid](
        # Pointers to matrices
        x,
        weight,
        bias,
        conv_state,
        cache_seqlens,
        conv_state_indices,
        num_accepted_tokens,
        # ``x`` is only a placeholder pointer here; the kernel does not
        # touch it when SAVE_INTERMEDIATE is False.
        intermediate_conv_window
        if intermediate_conv_window is not None else x,
        out,
        # Matrix dimensions
        batch,
        dim,
        seqlen,
        state_len,
        num_cache_lines,
        # stride
        stride_x_seq,
        stride_x_dim,
        stride_x_token,
        stride_w_dim,
        stride_w_width,
        stride_istate_seq,
        stride_istate_dim,
        stride_istate_token,
        stride_state_indices,
        stride_inter_seq,
        stride_inter_step,
        stride_inter_dim,
        stride_inter_win,
        stride_o_seq,
        stride_o_dim,
        stride_o_token,
        # others
        pad_slot_id,
        # META
        HAS_BIAS=bias is not None,
        KERNEL_WIDTH=width,
        SILU_ACTIVATION=activation in ["silu", "swish"],
        IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
        IS_SPEC_DECODING=num_accepted_tokens is not None,
        NP2_STATELEN=np2_statelen,
        USE_PAD_SLOT=pad_slot_id is not None,
        BLOCK_N=128,
        SAVE_INTERMEDIATE=intermediate_conv_window is not None,
    )
    if unsqueeze:
        out = out.squeeze(-1)
    return out

View File

@@ -14,212 +14,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Optional
import os.path
from typing import Callable, Optional
import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_int8_w8a8: bool = False,
    use_int4_w4a8: bool = False,
    global_num_experts: Optional[int] = None,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_scale_bias: torch.Tensor = None,
    w2_scale_bias: torch.Tensor = None,
    # For TorchAir graph
    is_torchair: bool = False,
    # For Cube/Vector parallel
    shared_experts: Optional[Any] = None,
    quantized_x_for_share: Optional[Any] = None,
    dynamic_scale_for_share: Optional[Any] = None,
    # For load balance
    log2phy: torch.Tensor = None,
    global_redundant_expert_num: int = 0,
) -> torch.Tensor:
    """Apply the routed experts to ``hidden_states`` and return the result.

    The tokens are permuted by the communication method found on the current
    forward context, pushed through gate/up grouped matmul -> SwiGLU -> down
    grouped matmul, and finally un-permuted back into ``hidden_states``,
    which is the tensor that is returned.

    Only ``hidden_states``, ``w1``, ``w2``, ``topk_weights``, ``topk_ids``,
    ``use_int8_w8a8``, ``use_int4_w4a8``, ``expert_map``, ``w1_scale`` and
    ``w2_scale`` are read by this body; the remaining parameters are accepted
    but not used here.

    Raises:
        AssertionError: on a shape/stride/dtype violation, when a quantized
            mode is requested without both weight scales, or when the forward
            context carries no communication method.
    """
    # Input validation (order kept so the same check fires first).
    assert hidden_states.shape[1] == w1.shape[1], (
        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    supported_dtypes = [torch.float32, torch.float16, torch.bfloat16]
    assert hidden_states.dtype in supported_dtypes

    with_quant = use_int8_w8a8 or use_int4_w4a8
    down_scale = None
    down_output_dtype = None
    if with_quant:
        assert w1_scale is not None and w2_scale is not None, \
            "INT8 quantization requires weight scales."
        w1_scale = w1_scale.to(torch.float32)
        down_scale = [w2_scale]
        down_output_dtype = w2_scale.dtype

    comm = get_forward_context().moe_comm_method
    assert comm is not None, "Missing communication context"

    local_expert_count = w1.shape[0]
    permuted, expert_tokens, dynamic_scale, group_list_type = comm.permute(
        hidden_states, topk_ids, topk_weights, expert_map,
        local_expert_count, with_quant)

    # Gate/up projection; int32 accumulators are only requested for W8A8.
    gate_up = torch_npu.npu_grouped_matmul(
        x=[permuted],
        weight=[w1],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=torch.int32 if use_int8_w8a8 else None,
    )[0]

    if not with_quant:
        activated = torch_npu.npu_swiglu(gate_up)
        activated_scale = None
    else:
        # Fused dequant + SwiGLU + requant; the returned per-token scale is
        # forwarded to the down projection.
        activated, per_token_scale = torch_npu.npu_dequant_swiglu_quant(
            x=gate_up,
            weight_scale=w1_scale,
            activation_scale=dynamic_scale,
            bias=None,
            quant_scale=None,
            quant_offset=None,
            group_index=expert_tokens,
            activate_left=True,
            quant_mode=1,
        )
        activated_scale = [per_token_scale]

    down = torch_npu.npu_grouped_matmul(
        x=[activated],
        weight=[w2],
        scale=down_scale,
        per_token_scale=activated_scale,
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=down_output_dtype,
    )[0]

    # The return value of unpermute is ignored; hidden_states is what the
    # caller gets back.
    comm.unpermute(down, hidden_states)
    return hidden_states
def fused_experts_moge(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    moe_parallel_config: FusedMoEParallelConfig,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    expert_map: torch.Tensor = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """Sort-based fused-experts path built on two grouped matmuls.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
        moe_parallel_config: Only its ``ep_size`` is read.
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        global_num_experts: Total expert count; divided by ``ep_size`` to get
            the number of local experts.
        expert_map: Accepted but not read by this body.
        apply_router_weight_on_input: Accepted but not read by this body.

    Returns:
        hidden_states: Hidden states after routing.
    """
    ep_size = moe_parallel_config.ep_size
    experts_per_rank = global_num_experts // ep_size
    slots_per_token = top_k // ep_size

    num_tokens, _ = hidden_states.shape
    flat_ids = topk_ids.view(-1)

    # Order the (token, slot) pairs by expert id; the sort runs on a float
    # copy of the ids and the resulting permutation is kept as int32.
    order = torch.argsort(flat_ids.float()).to(torch.int32)
    tokens_in_expert_order = hidden_states.index_select(
        0, order // slots_per_token)

    local_expert_ids = torch.arange(0,
                                    experts_per_rank,
                                    dtype=topk_ids.dtype,
                                    device=topk_ids.device)
    hits = flat_ids.unsqueeze(-1) == local_expert_ids
    tokens_per_expert = hits.to(torch.float32).sum(0)
    routing_scales = topk_weights.view(-1).index_select(0,
                                                        order).unsqueeze(-1)
    # Cumulative token counts (group_list_type=0 below).
    group_list = tokens_per_expert.cumsum(dim=0).to(torch.int64)

    gate_up = torch_npu.npu_grouped_matmul(
        x=[tokens_in_expert_order],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    if not is_310p():
        gate_up = torch_npu.npu_swiglu(gate_up)
    else:
        # On 310P the activation is evaluated in fp32 and cast to fp16.
        gate_up = torch_npu.npu_swiglu(gate_up.to(torch.float32)).to(
            torch.float16)
    gate_up *= routing_scales

    down = torch_npu.npu_grouped_matmul(
        x=[gate_up],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    # Invert the permutation, then sum the slots belonging to each token.
    inverse_order = torch.argsort(order.float()).to(torch.int32)
    restored = down.index_select(0, inverse_order)
    return restored.reshape(num_tokens, slots_per_token, -1).sum(1)
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
@@ -235,67 +55,7 @@ def unquantized_fused_moe_init_func(self, *args, **kwargs):
self.use_aclgraph = (vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
def forward_oot_v01011(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Select experts for ``x`` and dispatch to one of the fused-expert paths.

    ``fused_experts_moge`` is used when the router returned fewer than
    ``top_k`` columns or when running on 310P; otherwise ``fused_experts`` is
    used. ``activation`` and the EPLB-related parameters are accepted but not
    read by this body.
    """
    topk_weights, topk_ids, _ = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=1.0,
        e_score_correction_bias=e_score_correction_bias,
        global_num_experts=global_num_experts)

    use_moge_path = topk_ids.shape[1] < top_k or is_310p()
    if not use_moge_path:
        return fused_experts(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    assert global_num_experts is not None
    return fused_experts_moge(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        moe_parallel_config=self.moe.moe_parallel_config,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        top_k=top_k,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input)
self.transpose = True
def forward_oot(
@@ -321,7 +81,7 @@ def forward_oot(
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids, _ = select_experts(
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -335,40 +95,35 @@ def forward_oot(
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
moe_parallel_config=self.moe.moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map)
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
if self.transpose:
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
self.transpose = False
else:
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
@@ -378,119 +133,88 @@ def process_weights_after_loading(self, layer):
class AscendFusedMoE(FusedMoE):
moe_counter = -1
def __init__(
self,
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype=None,
reduce_results=False,
renormalize=True,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
quant_config=None,
tp_size=None,
ep_size=None,
dp_size=None,
prefix="",
custom_routing_function=None,
scoring_func="softmax",
routed_scaling_fator: float = 1.0,
e_score_correction_bias=None,
apply_router_weight_on_input=False,
activation="silu",
enable_eplb=False,
num_redundant_experts=0,
has_bias=False,
):
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
reduce_results,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
ep_size,
dp_size,
prefix,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
num_redundant_experts,
has_bias,
)
else:
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
reduce_results,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
ep_size,
dp_size,
prefix,
custom_routing_function,
scoring_func,
routed_scaling_fator,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
num_redundant_experts,
has_bias,
)
setup_token_dispatchers(self.moe_config.ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_local_experts=self.local_num_experts)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
ascend_config = get_ascend_config()
self.dynamic_eplb = ascend_config.dynamic_eplb
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
# static eplb initializing with expert_map_path
if self.expert_map_path and os.path.exists(
self.expert_map_path) and os.access(self.expert_map_path,
os.R_OK):
self.expert_load_balancer = ExpertLoadBalancer(
self.expert_map_path, self.global_num_experts)
self.local_num_experts, self.expert_map = (
self.expert_load_balancer.get_rank_placement_map(
self.moe_instance_id, self.ep_rank))
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id, self.ep_rank).npu()
self.global_redundant_expert_num = (
self.expert_load_balancer.get_global_redundant_expert_num())
else:
# init moe.
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size, self.ep_rank, self.global_num_experts)
# dynamic eplb initializing with not expert_map_path
if self.dynamic_eplb:
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.local_num_experts, self.expert_map = determine_default_expert_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
self.log2phy = determine_default_log2phy_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
local_num_experts = (torch.sum(
self.expert_map != -1) if self.expert_map is not None else
self.global_num_experts)
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
setup_moe_comm_method(self.moe_config)
def update_expert_map(self, new_expert_map):
    """Replace the expert map held by this layer with ``new_expert_map``."""
    self.expert_map = new_expert_map
def get_map(self):
    """Return the expert map currently stored on this layer."""
    current_map = self.expert_map
    return current_map
def get_log2phy_map(self):
    """Return the layer's ``logical_to_physical_map`` attribute."""
    # NOTE(review): the constructor visible in this file stores its mapping
    # under `self.log2phy`, while this getter reads `logical_to_physical_map`
    # - confirm which attribute callers expect.
    mapping = self.logical_to_physical_map
    return mapping
def clear_moe_load(self):
    """Zero the per-expert load counters in place, if they exist.

    ``moe_load`` is only created by the constructor when dynamic EPLB is
    enabled, so the attribute may be absent altogether. It is looked up with
    a default so that calling this on a layer without load tracking is a
    no-op instead of an ``AttributeError``.
    """
    moe_load = getattr(self, "moe_load", None)
    if moe_load is not None:
        moe_load.zero_()
def maybe_all_reduce_tensor_model_parallel(
        self, final_hidden_states: torch.Tensor):
    """Override of the parent-class hook for the final all-reduce.

    NOTE(Yizhou): with `mc2commimpl` and `alltoallcommimpl` the outputs are
    already aggregated across tensor parallel ranks in the `finalize`
    function, so no all-reduce is needed. With `allgathercommimpl` each rank
    only holds partial outputs, so the all-reduce is still required. The
    decision is delegated to the custom op called below.
    """
    maybe_all_reduce = torch.ops.vllm.maybe_all_reduce_tensor_model_parallel
    return maybe_all_reduce(final_hidden_states)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
# TODO: Can we refactor this logic to model_runner?
# TODO: Adjusted logic to differentiate between A2 and A3, we check ep_size here since mc2 only support ep_size >= 16 on A3 now
if self.moe_config.ep_size < 16:
moe_comm_method_name = "allgathercommimpl"
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits)
hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
@@ -514,6 +238,12 @@ class AscendFusedMoE(FusedMoE):
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
if isinstance(final_hidden_states, tuple):
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
if self.dynamic_eplb:
self.moe_load += expert_tokens if group_list_type else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
@@ -521,11 +251,118 @@ class AscendFusedMoE(FusedMoE):
return final_hidden_states
def transpose_weight(self, loaded_weight, expert_data, shard_dim):
    """Transpose ``loaded_weight`` when neither of its first two dims matches.

    Keeps training-side and inference-side weight layouts compatible during
    RL weight updates: if both dim 1 and dim 0 of ``loaded_weight`` differ
    from ``expert_data``, the weight is transposed (and made contiguous) and
    ``shard_dim`` is flipped between 0 and 1. Otherwise both are returned
    untouched.

    Returns:
        The ``(loaded_weight, shard_dim)`` pair to use for loading.
    """
    if (loaded_weight.shape[1] == expert_data.shape[1]
            or loaded_weight.shape[0] == expert_data.shape[0]):
        # At least one of the two dims already lines up: nothing to do.
        return loaded_weight, shard_dim
    flipped_dim = int(not shard_dim)
    return loaded_weight.transpose(0, 1).contiguous(), flipped_dim
def _load_w13(self,
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
loaded_weight, shard_dim = self.transpose_weight(
loaded_weight, expert_data, shard_dim)
shard_size = expert_data.shape[shard_dim] // 2
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
def _load_w2(self,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
loaded_weight, shard_dim = self.transpose_weight(
loaded_weight, expert_data, shard_dim)
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
# MoE layer that evaluates a shared-experts module alongside the routed
# experts and returns both results as a (shared_out, fused_out) pair.
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# Store the shared-experts module and, when the Ascend config enables
# `multistream_overlap_shared_expert`, create a dedicated NPU stream for it.
# Only AscendFusedMoE.__init__ is invoked explicitly here, and it receives
# the keyword arguments only. `use_overlapped` is stored but not read
# anywhere else in this class.
def __init__(
self,
shared_experts: torch.nn.Module,
use_overlapped: bool = True,
**kwargs,
):
AscendFusedMoE.__init__(self, **kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
# Stays None unless multistream overlap is enabled below.
self.shared_expert_stream = None
ascend_config = get_ascend_config()
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
if self.multistream_overlap_shared_expert:
self.shared_expert_stream = torch.npu.Stream()
# Delegate to AscendFusedMoE.forward and hand its two-element result back
# as (shared_out, fused_out).
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out, fused_out = AscendFusedMoE.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_out, fused_out
# Compute the shared-experts output (optionally on the side stream) and the
# routed-experts output, and return them as (shared_out, fused_output).
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
self.shared_expert_stream.wait_stream( # type: ignore
torch.npu.current_stream())
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether the all-reduce below runs inside this `with` block or after it -
# confirm against the properly indented file.
with npu_stream_switch(self.shared_expert_stream,
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
shared_out = self._shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
# The shared output is all-reduced here only for the ALLTOALL and MC2
# communication types.
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
shared_out = tensor_model_parallel_all_reduce(shared_out)
# Routed experts go through the regular AscendFusedMoE implementation.
fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
return shared_out, fused_output
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011
else:
UnquantizedFusedMoEMethod.forward_oot = forward_oot
UnquantizedFusedMoEMethod.forward_oot = forward_oot

218
vllm_ascend/ops/fla.py Normal file
View File

@@ -0,0 +1,218 @@
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
# mypy: ignore-errors
import torch
import torch.nn.functional as F
import triton
from vllm.model_executor.layers.fla.ops.layernorm_guard import \
layer_norm_fwd_kernel
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Validate the inputs and launch ``layer_norm_fwd_kernel`` on a 2-D ``x``.

    Args:
        x: 2-D input whose last dimension has stride 1.
        weight: 1-D tensor with one entry per column of ``x``.
        bias: optional 1-D tensor with the same shape as ``weight``.
        eps: epsilon value forwarded to the kernel.
        z: optional tensor with the same shape as ``x``.
        out: optional pre-allocated output with the same shape as ``x``; a
            new tensor is allocated when omitted.
        group_size: width of each normalization group; defaults to the full
            row and must divide the number of columns.
        norm_before_gate: forwarded to the kernel as ``NORM_BEFORE_GATE``.
        is_rms_norm: forwarded to the kernel as ``IS_RMS_NORM``; when set, no
            ``mean`` buffer is allocated.

    Returns:
        ``(out, mean, rstd)`` where ``mean`` is ``None`` for RMS norm.

    Raises:
        AssertionError: on a shape or stride violation.
        RuntimeError: if a group does not fit into one fused block.
    """
    num_rows, num_cols = x.shape
    if group_size is None:
        group_size = num_cols
    assert num_cols % group_size == 0
    groups_per_row = num_cols // group_size

    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (num_rows, num_cols)
    assert weight.shape == (num_cols, )
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (num_cols, )

    # Output buffer: reuse the caller's tensor or allocate a fresh one.
    if out is None:
        out = torch.empty_like(x)
    else:
        assert out.shape == x.shape
    assert out.stride(-1) == 1

    # One statistic per (row, group) pair.
    stats_shape = (groups_per_row * num_rows, )
    if is_rms_norm:
        mean = None
    else:
        mean = torch.empty(stats_shape, dtype=torch.float32, device=x.device)
    rstd = torch.empty(stats_shape, dtype=torch.float32, device=x.device)

    # Less than 64KB per feature: enqueue fused kernel
    max_fused_elems = 65536 // x.element_size()
    block_n = min(max_fused_elems, triton.next_power_of_2(group_size))
    if group_size > block_n:
        raise RuntimeError(
            "This layer norm doesn't support feature dim >= 64KB.")
    # Warp-count heuristic: one warp per 256 block elements, clamped to [1, 8].
    num_warps = min(max(block_n // 256, 1), 8)
    launch_grid = (num_rows, groups_per_row)
    z_row_stride = z.stride(0) if z is not None else 0

    with torch.npu.device(x.device.index):
        layer_norm_fwd_kernel[launch_grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z_row_stride,
            num_rows,
            group_size,
            eps,
            BLOCK_N=block_n,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            num_warps=num_warps,
        )
    return out, mean, rstd
class LayerNormFn(torch.autograd.Function):
    """Forward-only wrapper that flattens inputs and calls ``_layer_norm_fwd``."""

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        original_shape = x.shape

        # Collapse all leading dimensions so the helper sees a 2-D tensor
        # whose last dimension has stride 1.
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()

        if z is not None:
            assert z.shape == original_shape
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()

        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()

        # Only the normalized output is used; nothing is saved on ``ctx``.
        normed, _mean, _rstd = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        return normed.reshape(original_shape)
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Pure-PyTorch chunked gated delta rule.

    All five tensor inputs have dims 1 and 2 swapped and are promoted to
    float32. The axis that was dim 1 of the inputs is zero-padded to a
    multiple of ``chunk_size`` and scanned chunk by chunk while a recurrent
    state is carried between chunks. ``query`` and ``key`` are repeated twice
    along the other swapped axis so they line up with ``value``.
    NOTE(review): inputs are presumably laid out (batch, tokens, heads, dim)
    with twice as many value heads as key heads - confirm with the caller.

    Args:
        query, key, value: 4-D tensors.
        g: gate tensor with one dim fewer than ``value``; it is accumulated
            with a cumulative sum inside each chunk and exponentiated.
        beta: tensor with one dim fewer than ``value``; scales ``key`` and
            ``value``.
        chunk_size: length of each chunk.
        initial_state: optional starting recurrent state; zeros when omitted.
        output_final_state: return the last recurrent state instead of None.
        use_qk_l2norm_in_kernel: L2-normalize ``query`` and ``key`` along the
            last dim first.

    Returns:
        ``(core_attn_out, last_recurrent_state)``: the output in the input
        layout and in ``query``'s original dtype, plus the final float32
        state or ``None``.
    """
    out_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = F.normalize(query, p=2, dim=-1)
        key = F.normalize(key, p=2, dim=-1)

    def _swap_and_promote(t):
        return t.transpose(1, 2).contiguous().to(torch.float32)

    query = _swap_and_promote(query)
    key = _swap_and_promote(key)
    value = _swap_and_promote(value)
    beta = _swap_and_promote(beta)
    g = _swap_and_promote(g)

    # After the swap, dim 1 is the original dim 2 and dim 2 (``n_steps``) is
    # the axis that gets padded and chunked.
    batch_size, n_groups, n_steps, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    pad_size = (chunk_size - n_steps % chunk_size) % chunk_size
    step_pad_4d = (0, 0, 0, pad_size)
    step_pad_3d = (0, pad_size)
    query = F.pad(query, step_pad_4d).repeat_interleave(2, dim=1)
    key = F.pad(key, step_pad_4d).repeat_interleave(2, dim=1)
    value = F.pad(value, step_pad_4d)
    beta = F.pad(beta, step_pad_3d)
    g = F.pad(g, step_pad_3d)
    padded_steps = n_steps + pad_size

    scale = 1 / (query.shape[-1]**0.5)
    query = query * scale

    beta_col = beta.unsqueeze(-1)
    v_beta = value * beta_col
    k_beta = key * beta.unsqueeze(-1)

    # reshape to chunks
    def _chunked(t):
        return t.reshape(t.shape[0], t.shape[1], -1, chunk_size, t.shape[-1])

    query, key, value, k_beta, v_beta = map(
        _chunked, (query, key, value, k_beta, v_beta))
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)

    ones_sq = torch.ones(chunk_size,
                         chunk_size,
                         dtype=torch.bool,
                         device=query.device)
    upper_with_diag = torch.triu(ones_sq, diagonal=0)

    # chunk decay
    g = g.cumsum(dim=-1)
    pairwise_gap = g.unsqueeze(-1) - g.unsqueeze(-2)
    decay_mask = (pairwise_gap.tril().exp().float()).tril()

    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(
        upper_with_diag, 0)
    # Row-by-row forward substitution over the strictly lower triangle.
    for r in range(1, chunk_size):
        cur_row = attn[..., r, :r].clone()
        upper_left = attn[..., :r, :r].clone()
        attn[..., r, :r] = cur_row + (cur_row.unsqueeze(-1) *
                                      upper_left).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)

    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))

    if initial_state is None:
        # Sized with the pre-repeat dim-1 extent of ``key``.
        last_recurrent_state = torch.zeros(batch_size, n_groups, k_head_dim,
                                           v_head_dim).to(value)
    else:
        last_recurrent_state = initial_state.to(value)

    core_attn_out = torch.zeros_like(value)
    strictly_upper = torch.triu(torch.ones(chunk_size,
                                           chunk_size,
                                           dtype=torch.bool,
                                           device=query.device),
                                diagonal=1)

    # for each chunk
    n_chunks = padded_steps // chunk_size
    for c in range(0, n_chunks):
        q_c, k_c, v_c = query[:, :, c], key[:, :, c], value[:, :, c]
        g_c = g[:, :, c]
        intra = (q_c @ k_c.transpose(-1, -2) *
                 decay_mask[:, :, c]).masked_fill_(strictly_upper, 0)
        v_prime = (k_cumdecay[:, :, c]) @ last_recurrent_state
        v_new = v_c - v_prime
        inter = (q_c * g[:, :, c, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, c] = inter + intra @ v_new
        # Decay the carried state to the end of the chunk, then add this
        # chunk's contribution.
        last_recurrent_state = (
            last_recurrent_state * g[:, :, c, -1, None, None].exp() +
            (k_c * (g[:, :, c, -1, None] - g_c).exp()[..., None]).transpose(
                -1, -2) @ v_new)

    if not output_final_state:
        last_recurrent_state = None

    # Merge the chunk axes, drop the padding and restore the input layout.
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
                                          core_attn_out.shape[1], -1,
                                          core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :n_steps]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(out_dtype)
    return core_attn_out, last_recurrent_state

View File

@@ -19,13 +19,9 @@ import os
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
@@ -39,70 +35,16 @@ from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
get_all_reduce_merge_state,
get_rm_router_logits_state, is_310p)
def unified_fused_experts_eager(hidden_states: torch.Tensor,
                                w1: torch.Tensor,
                                w2: torch.Tensor,
                                topk_weights: torch.Tensor,
                                topk_ids: torch.Tensor,
                                row_idx: torch.Tensor,
                                expert_map: Optional[torch.Tensor] = None,
                                log2phy: Optional[torch.Tensor] = None,
                                global_redundant_expert_num: int = 0,
                                w1_scale: Optional[torch.Tensor] = None,
                                w1_scale_bias: Optional[torch.Tensor] = None,
                                w2_scale: Optional[torch.Tensor] = None,
                                w2_scale_bias: Optional[torch.Tensor] = None,
                                shared_experts: Optional[torch.Tensor] = None,
                                shared_gate_up: Optional[Any] = None,
                                shared_dequant_scale: Optional[Any] = None,
                                mc2_mask: Optional[torch.Tensor] = None,
                                apply_router_weight_on_input: bool = False,
                                with_quant: bool = False):
    """Dispatch tokens, run the expert MLP, and combine the results.

    The token dispatcher is taken from the current forward context. Its
    ``token_dispatch`` result is a mapping from which ``hidden_states`` and
    ``group_list`` are required, while ``dynamic_scale``, ``group_list_type``
    and ``topk_scales`` are optional. The MLP output is passed to
    ``token_combine`` and that result is returned.
    """
    dispatcher = get_forward_context().token_dispatcher

    dispatched = dispatcher.token_dispatch(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        row_idx=row_idx,
        expert_map=expert_map,
        log2phy=log2phy,
        global_redundant_expert_num=global_redundant_expert_num,
        shared_experts=shared_experts,
        shared_gate_up=shared_gate_up,
        shared_dequant_scale=shared_dequant_scale,
        mc2_mask=mc2_mask,
        apply_router_weight_on_input=apply_router_weight_on_input,
        with_quant=with_quant)

    mlp_out = unified_apply_mlp(
        hidden_states=dispatched["hidden_states"],
        w1=w1,
        w1_scale=w1_scale,
        w2=w2,
        w2_scale=w2_scale,
        group_list=dispatched["group_list"],
        dynamic_scale=dispatched.get("dynamic_scale"),
        group_list_type=dispatched.get("group_list_type"),
        w1_scale_bias=w1_scale_bias,
        w2_scale_bias=w2_scale_bias,
        topk_scales=dispatched.get("topk_scales"),
        with_quant=with_quant)

    return dispatcher.token_combine(mlp_out)
get_rm_router_logits_state, is_310p,
vllm_version_is)
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -115,6 +57,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_model_len = vllm_config.model_config.max_model_len
get_ascend_config()
self.dynamic_eplb = get_ascend_config().dynamic_eplb
try:
device_group = get_mc2_group().device_group
@@ -182,17 +125,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
if enable_force_load_balance and not self.use_aclgraph:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
return unified_fused_experts_eager(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
shared_experts=shared_experts,
mc2_mask=kwargs.get(
"mc2_mask", None),
with_quant=False)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map,
shared_experts=shared_experts,
need_trans=True,
dynamic_eplb=self.dynamic_eplb)
class AscendFusedMoE(FusedMoE):
@@ -290,42 +235,67 @@ class AscendFusedMoE(FusedMoE):
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
ascend_config = get_ascend_config()
expert_map_path = ascend_config.expert_map_path
if expert_map_path and os.path.exists(expert_map_path):
# moe expert load balance
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
self.global_num_experts)
self.local_num_experts, self.expert_map = \
expert_load_balancer.get_rank_placement_map(
self.moe_instance_id,
get_ep_group().rank_in_group)
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id,
get_ep_group().rank_in_group)
self.global_redundant_expert_num = \
expert_load_balancer.get_global_redundant_expert_num()
self.dynamic_eplb = ascend_config.dynamic_eplb
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
# static eplb initializing with expert_map_path
if self.expert_map_path and os.path.exists(
self.expert_map_path) and os.access(self.expert_map_path,
os.R_OK):
self.expert_load_balancer = ExpertLoadBalancer(
self.expert_map_path, self.global_num_experts)
self.local_num_experts, self.expert_map = (
self.expert_load_balancer.get_rank_placement_map(
self.moe_instance_id, self.ep_rank))
self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id, self.ep_rank).npu()
self.global_redundant_expert_num = (
self.expert_load_balancer.get_global_redundant_expert_num())
else:
# Create a tensor of size num_experts filled with -1
# init moe.
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size,
get_ep_group().rank_in_group, self.global_num_experts)
self.ep_size, self.ep_rank, self.global_num_experts)
# dynamic eplb initializing with not expert_map_path
if self.dynamic_eplb:
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.local_num_experts, self.expert_map = determine_default_expert_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
self.log2phy = determine_default_log2phy_map(
self.global_num_experts, self.ep_size, self.ep_rank,
self.global_redundant_expert_num)
local_num_experts = (torch.sum(self.expert_map != -1)
if self.expert_map is not None else num_experts)
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
quant_config=quant_config)
if vllm_version_is("0.10.2"):
moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
quant_config=quant_config)
else:
moe = FusedMoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=params_dtype,
)
self.moe_config = moe
# TODO: The self.moe_config.tp_size here is not correct, fixme soon
if quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
@@ -337,6 +307,11 @@ class AscendFusedMoE(FusedMoE):
local_num_experts = torch.sum(self.expert_map != -1) \
if self.expert_map is not None else num_experts
self.moe_load = None
if self.dynamic_eplb:
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
moe_quant_params = {
"num_experts": local_num_experts,
"hidden_size": hidden_size,
@@ -354,34 +329,27 @@ class AscendFusedMoE(FusedMoE):
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.token_dispatcher = None
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
setup_token_dispatchers(
ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_global_redundant_experts=self.global_redundant_expert_num,
num_local_experts=self.local_num_experts)
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num
def naive_multicast(self, x: torch.Tensor,
                    cu_tokens_across_dp_cpu: torch.Tensor):
    """Assemble the rows of every data-parallel rank into one buffer.

    A buffer with ``cu_tokens_across_dp_cpu[-1]`` rows and ``x.size(1)``
    columns is allocated on ``x``'s device and dtype. This rank's ``x`` is
    copied into its own row range, then each rank's row range is passed
    through ``get_dp_group().broadcast`` together with that rank's index.

    Args:
        x: 2-D tensor holding this rank's rows.
        cu_tokens_across_dp_cpu: cumulative row counts indexed by dp rank;
            rank ``r`` owns rows ``[cu[r - 1], cu[r])`` and rank 0 starts
            at row 0.

    Returns:
        The filled buffer.
    """
    assert len(x.shape) == 2

    def _row_span(rank):
        # Row range owned by ``rank`` according to the cumulative counts.
        lower = 0 if rank == 0 else cu_tokens_across_dp_cpu[rank - 1]
        upper = cu_tokens_across_dp_cpu[rank]
        return lower, upper

    total_rows = cu_tokens_across_dp_cpu[-1]
    buffer = torch.empty((total_rows, x.size(1)),
                         device=x.device,
                         dtype=x.dtype)

    # Place the local rows first so they are in position for the broadcast.
    own_lower, own_upper = _row_span(self.dp_rank)
    buffer[own_lower:own_upper, :].copy_(x)

    for rank in range(self.dp_size):
        lower, upper = _row_span(rank)
        get_dp_group().broadcast(buffer[lower:upper, :], rank)
    return buffer
setup_moe_comm_method(self.moe_config)
def update_expert_map(self, new_expert_map):
    """Replace the expert map stored on this layer with ``new_expert_map``."""
    self.expert_map = new_expert_map
def get_map(self):
    """Return the expert map currently stored on this layer."""
    current_map = self.expert_map
    return current_map
def get_log2phy_map(self):
    """Return this layer's logical-to-physical expert map.

    The visible constructor code stores the map it builds (from the expert
    load balancer or the default dynamic-EPLB placement) in
    ``self.log2phy``, so that attribute is preferred. When ``log2phy`` is
    absent or ``None`` the previously returned attribute,
    ``self.logical_to_physical_map``, is used so existing callers keep
    working.
    """
    log2phy = getattr(self, "log2phy", None)
    if log2phy is not None:
        return log2phy
    # NOTE(review): no visible code in this layer assigns
    # ``logical_to_physical_map``; presumably it comes from the base class.
    # Confirm before dropping this fallback.
    return self.logical_to_physical_map
def clear_moe_load(self):
    """Zero the ``moe_load`` counter in place; do nothing when it is None."""
    load = self.moe_load
    if load is None:
        return
    load.zero_()
def forward(self,
hidden_states: torch.Tensor,
@@ -391,8 +359,7 @@ class AscendFusedMoE(FusedMoE):
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None,
gate=None,
replace_allreduce: bool = False,
_metadata_for_padding: Optional[MetadataForPadding] = None):
replace_allreduce: bool = False):
assert self.quant_method is not None
@@ -401,10 +368,7 @@ class AscendFusedMoE(FusedMoE):
else:
real_top_k = self.top_k
num_tokens, hidden_size = hidden_states.shape
forward_context = get_forward_context()
fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
@@ -413,74 +377,16 @@ class AscendFusedMoE(FusedMoE):
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)
mc2_mask = forward_context.mc2_mask
enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
tp_size = get_tensor_model_parallel_world_size()
if enable_sp:
tp_rank = get_tensor_model_parallel_rank()
mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
if forward_context.sp_enabled:
replace_allreduce = True
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
if fused_moe_state in {FusedMoEState.MC2}:
padding_size = forward_context.padded_num_tokens
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
if not self.enable_shared_expert_dp:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
if self.dp_size > 1:
if fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
max_tokens_across_dp = forward_context.max_tokens_across_dp
if num_tokens < max_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0, max_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0, max_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = get_dp_group().all_gather(router_logits, 0)
elif fused_moe_state == FusedMoEState.NaiveMulticast:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_dp_cpu)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
enable_shared_expert_dp=self.enable_shared_expert_dp,
rm_router_logits=self.rm_router_logits,
replace_allreduce=replace_allreduce,
gate=gate)
# Matrix multiply.
e_hidden_states = self.quant_method.apply(
@@ -503,53 +409,27 @@ class AscendFusedMoE(FusedMoE):
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=None,
mc2_mask=mc2_mask,
token_dispatcher=self.token_dispatcher,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
)
group_list_type = None
if shared_experts:
if isinstance(e_hidden_states, tuple):
if isinstance(e_hidden_states,
tuple) and len(e_hidden_states) == 2:
e_hidden_states, shared_hidden_states = e_hidden_states
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce and not self.enable_shared_expert_dp):
if tp_size > 1:
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
dispose_tensor(e_hidden_states)
else:
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens]
elif self.dp_size > 1 and not self.enable_shared_expert_dp:
if fused_moe_state == FusedMoEState.NaiveMulticast:
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
final_hidden_states = get_dp_group().all_reduce(
e_hidden_states)
final_hidden_states = final_hidden_states[start:end, :]
dispose_tensor(e_hidden_states)
elif fused_moe_state == FusedMoEState.AllGather:
final_hidden_states = data_parallel_reduce_scatter(
e_hidden_states, dim=0)
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
else:
final_hidden_states = e_hidden_states
else:
final_hidden_states = e_hidden_states
if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 3:
e_hidden_states, group_list_type, expert_tokens = e_hidden_states
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
]:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if self.dynamic_eplb and group_list_type is not None:
self.moe_load += expert_tokens if group_list_type else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=e_hidden_states,
reduce_results=(not self.all_reduce_merge))
if shared_experts:
return final_hidden_states, shared_hidden_states

View File

@@ -15,50 +15,124 @@
# This file is a part of the vllm-ascend project.
#
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, cast
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
class AddRMSNormW8A8Quant(RMSNorm):
# Fuse AddRmsNorm and W8A8 quantization ops together
# Residual-add + RMSNorm helper. It is written as a free function that takes
# the norm layer as ``self`` and reads ``self.weight`` and
# ``self.variance_epsilon`` from it.
#
# Three paths are visible:
#   * ``layer`` given and not on 310P: fused add + RMSNorm + quantization
#     using ``layer.aclnn_input_scale`` / ``layer.aclnn_input_offset``.
#   * on 310P: the residual add is done manually, then a plain RMSNorm.
#   * otherwise: fused add + RMSNorm.
# Always returns the pair ``(x, residual)``.
def _addrmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: torch.Tensor,
# Optional module supplying the quantization scale/offset for the fused op.
layer: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# Imported inside the function rather than at module import time.
import torch_npu
from vllm_ascend.utils import is_310p
if layer is not None and not is_310p():
# The middle output of the fused op is discarded.
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
else:
if is_310p():
# Manual residual add: sum in x's dtype, then store the sum back as the
# new residual in the residual's original dtype.
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
# NOTE(review): indentation is lost in this view; presumably this call runs
# after every branch rather than only inside the last ``else`` - confirm.
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
# RMSNorm subclass whose ``forward_oot`` calls torch_npu kernels and, when a
# residual is given, delegates to ``_addrmsnorm_forward_oot``.
class AscendRMSNorm(RMSNorm):
# With a residual: returns ``(normed_x, new_residual)``.
# Without a residual: returns only the normalized tensor.
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
if residual is not None:
# The residual is passed through ``maybe_chunk_residual`` first; afterwards
# its first dimension must match x's.
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
# Reading the property below may mutate the forward context (see the
# property's comments).
x, residual = _addrmsnorm_forward_oot(
self, x, residual, self.next_need_quant_fusion_linear)
return x, residual
# No residual: plain RMSNorm. The second output is bound to ``residual`` but
# never used.
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x
# Returns the linear layer that follows this norm when it qualifies for the
# fused add-RMSNorm-quant path, else None.
# Side effects: each access that gets past the early returns flips
# ``forward_context.fusion_linear`` between "qkv_dense" and "gate_up_dense"
# and may increment ``forward_context.layer_idx``.
@property
def next_need_quant_fusion_linear(self):
try:
forward_context = get_forward_context()
# Fusion disabled, or layer_idx has reached num_hidden_layers: nothing to fuse.
if not forward_context.addrmsnorm_quant_fusion_enabled or \
forward_context.layer_idx == forward_context.num_hidden_layers:
return None
except AssertionError:
# NOTE(review): presumably raised when no forward context is set - confirm.
return None
next_linear = None
model_instance = forward_context.model_instance
layer_idx = forward_context.layer_idx
fusion_linear = forward_context.fusion_linear
# NOTE(review): ``next_linear`` was already set to None a few lines above;
# this second assignment is redundant.
next_linear = None
if fusion_linear == "qkv_dense":
next_linear = model_instance.model.layers[
layer_idx].self_attn.qkv_proj
forward_context.fusion_linear = "gate_up_dense"
elif fusion_linear == "gate_up_dense":
next_linear = model_instance.model.layers[
layer_idx].mlp.gate_up_proj
forward_context.fusion_linear = "qkv_dense"
# if prefetch_mlp_weight enabled, following accumulation operation
# does not need to be repeated
# NOTE(review): indentation is lost in this view; presumably this increment
# belongs to the "gate_up_dense" branch - confirm.
if not forward_context.prefetch_mlp_enabled:
forward_context.layer_idx += 1
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
# The candidate is kept only if its inner quant method is the W8A8 linear
# method; any other kind is dropped.
if next_linear is not None and \
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
next_linear = None
return next_linear
class AscendQuantRMSNorm(AscendRMSNorm):
def __init__(
self,
hidden_size: int,
layer: torch.nn.Module,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.layer = layer
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
def forward(
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
self.layer.aclnn_input_scale,
self.layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x
x, residual = super().forward_oot(x, residual)
return x.add_(self.bias), residual
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
class AscendRMSNorm(RMSNorm):
class AscendGemmaRMSNorm(GemmaRMSNorm):
def forward_oot(
self,
@@ -73,13 +147,13 @@ class AscendRMSNorm(RMSNorm):
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
x, residual, 1.0 + self.weight, self.variance_epsilon)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
self.variance_epsilon)
return x

View File

@@ -1,45 +1,159 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
To customize the communication groups or the forward pass of the classes in
this file, add new linear operations in linear_op.py.
The classes in this file should not be modified, including AscendQKVParallelLinear,
AscendMergedColumnParallelLinear, AscendRowParallelLinear and
AscendColumnParallelLinear.
"""
from typing import Optional, Union
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
LinearBase,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.distributed import divide
from vllm.model_executor.layers.linear import ( # noqa
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
RowParallelLinear, UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (
get_mlp_tensor_model_parallel_rank,
get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
from vllm_ascend.ops.linear_op import (get_column_parallel_op,
get_row_parallel_op)
class AscendMlpColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
class AscendLinearBase(LinearBase):
    """Linear base whose constructor sets its attributes directly.

    ``__init__`` calls only ``nn.Module.__init__`` (not the parent
    ``LinearBase`` constructor) and then stores the sizes, dtype, quant
    config, prefix, quant method, ``return_bias`` and ``disable_tp`` on
    the instance.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        nn.Module.__init__(self)

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        # Fall back to torch's current default dtype when none is given.
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)
        self.quant_config = quant_config
        self.prefix = prefix

        # Pick the quant method: ask the quant config when one is supplied,
        # otherwise use the unquantized linear method.
        chosen_method: Optional[QuantizeMethodBase]
        if quant_config is not None:
            chosen_method = quant_config.get_quant_method(self, prefix=prefix)
        else:
            chosen_method = UnquantizedLinearMethod()
        self.quant_method = chosen_method

        self.return_bias = return_bias
        self.disable_tp = disable_tp
class AscendQKVParallelLinear(QKVParallelLinear):
    """Linear layers for the attention's QKV transformation.

    One packed projection produces the query, key and value vectors; the
    weight matrix is concatenated along the output dimension and the
    layer is parallelized along the head dimension. When there are fewer
    key/value heads than query heads (multi-query / grouped-query
    attention), the key/value heads may be replicated while the query
    heads are partitioned.

    The tensor-parallel size and an optional custom op come from
    ``get_column_parallel_op``; when a custom op is present, ``forward``
    delegates to it.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # The second element of the returned triple is not needed here.
        self.custom_op, _, tp_size = get_column_parallel_op(
            disable_tp, prefix, self)
        # TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        # Without an explicit KV head count, use as many KV heads as query
        # heads.
        kv_heads_total = (total_num_heads
                          if total_num_kv_heads is None else
                          total_num_kv_heads)
        self.total_num_kv_heads = kv_heads_total

        # Divide the weight matrix along the last dimension.
        self.num_heads = divide(total_num_heads, tp_size)
        if tp_size < kv_heads_total:
            # More KV heads than ranks: split them across ranks.
            self.num_kv_heads = divide(kv_heads_total, tp_size)
            self.num_kv_head_replicas = 1
        else:
            # At least as many ranks as KV heads: one KV head per rank,
            # replicated across groups of ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, kv_heads_total)

        q_width = self.num_heads * head_size * tp_size
        kv_width = self.num_kv_heads * head_size * tp_size
        self.output_sizes = [
            q_width,  # q_proj
            kv_width,  # k_proj
            kv_width,  # v_proj
        ]
        packed_width = (self.num_heads +
                        2 * self.num_kv_heads) * tp_size * head_size

        AscendColumnParallelLinear.__init__(self,
                                            input_size=hidden_size,
                                            output_size=packed_width,
                                            bias=bias,
                                            gather_output=False,
                                            skip_bias_add=skip_bias_add,
                                            params_dtype=params_dtype,
                                            quant_config=quant_config,
                                            prefix=prefix,
                                            return_bias=return_bias,
                                            disable_tp=disable_tp)

    def forward(
        self,
        input_,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Apply the custom op when one is set, else the parent forward."""
        op = self.custom_op
        if op is None:
            return super().forward(input_)
        return op.apply(input_)
class AscendMergedColumnParallelLinear(MergedColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
@@ -48,73 +162,46 @@ class AscendMlpColumnParallelLinear(ColumnParallelLinear):
def __init__(
self,
input_size: int,
output_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the last dimension.
if prefix.find("gate_up_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
LinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
disable_tp, prefix, self)
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.output_sizes = output_sizes
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
AscendColumnParallelLinear.__init__(self,
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.custom_op is not None:
return self.custom_op.apply(input_)
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
return super().forward(input_)
class AscendMlpRowParallelLinear(RowParallelLinear):
class AscendRowParallelLinear(RowParallelLinear):
"""Linear layer with row parallelism.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
@@ -133,28 +220,25 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
if prefix.find("down_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
self.custom_op, self.tp_rank, self.tp_size = get_row_parallel_op(
disable_tp, prefix, self)
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
# Divide the weight matrix along the first dimension.
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
LinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@@ -184,66 +268,22 @@ class AscendMlpRowParallelLinear(RowParallelLinear):
else:
self.register_parameter("bias", None)
if self.custom_op is not None:
self.custom_op.update_attrs()
def forward(
self,
input_,
is_prefill: bool = True,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.enable_mlp_optimze:
tp_rank = get_mlp_tensor_model_parallel_rank()
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_mlp_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0
or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
# output = output[:num_tokens,:]
# dispose_tensor(output_parallel)
else:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
if self.custom_op is not None:
return self.custom_op.apply(input_)
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0
or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
return super().forward(input_)
class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
class AscendColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
@@ -252,58 +292,76 @@ class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
def __init__(
self,
input_size: int,
output_sizes: list[int],
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
if prefix.find("gate_up_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
self.custom_op, self.tp_rank, self.tp_size = get_column_parallel_op(
disable_tp, prefix, self)
# TODO(realliujiaxu): Replace the initialization code below with super().__init__ after linear of vllm supports custom comm group
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
AscendMlpColumnParallelLinear.__init__(self,
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
self.register_parameter("bias", None)
if self.custom_op is not None:
self.custom_op.update_attrs()
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
# self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
# Matrix multiply.
assert self.quant_method is not None
if self.enable_mlp_optimze:
input2_ = get_mlp_tp_group().all_gather(input_, 0)
output = self.quant_method.apply(self, input2_, bias)
else:
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
if self.custom_op is not None:
return self.custom_op.apply(input_)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
return super().forward(input_)

View File

@@ -0,0 +1,459 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file extends the functionality of linear operations by encapsulating custom
communication groups and forward functions into classes (linear ops).
Current class inheritance structure:
CustomTensorParallelOp
├── CustomColumnParallelOp
│   ├── MLPColumnParallelOp
│   ├── SequenceMergedColumnParallelOp
│   └── SequenceQKVParallelOp
└── CustomRowParallelOp
    ├── MLPRowParallelOp
    ├── OProjRowParallelOp
    ├── MatmulAllreduceRowParallelOp
    └── SequenceRowParallelOp
How to add a new linear op? Taking a column parallel op as an example:
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group property
3. Override the apply_impl method according to requirements; apply() wraps it and replaces the original linear.forward
4. Add selection logic for MyColumnParallelOp in the get_column_parallel_op function, typically based on the prefix and configuration switches
A row parallel op follows a similar approach - inherit from CustomRowParallelOp and register the new class in get_row_parallel_op.
"""
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch_npu
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import split_tensor_along_last_dim
from vllm.distributed.parallel_state import get_tp_group
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable)
class CustomTensorParallelOp:
    """Base class for a custom forward implementation attached to a linear layer.

    The op keeps a reference to its owning ``layer`` and mirrors a handful of
    the layer's attributes so that ``apply`` can run without reaching back
    into the layer on every call. Subclasses implement ``apply_impl`` and may
    override ``comm_group`` to use a different communication group.
    """

    def __init__(self, layer):
        self.layer = layer
        # Placeholders; the real values are copied from the layer by
        # update_attrs() once the layer has finished initializing.
        self.bias = None
        self.skip_bias_add = None
        self.return_bias = None
        self.quant_method = None

    @property
    def comm_group(self):
        """Communication group used by this op (the TP group by default)."""
        return get_tp_group()

    @property
    def tp_rank(self):
        """Rank of the current process inside ``comm_group``."""
        return self.comm_group.rank_in_group

    @property
    def tp_size(self):
        """Number of ranks in ``comm_group``."""
        return self.comm_group.world_size

    def update_attrs(self):
        """Copy the attributes needed by ``apply`` from the owning layer.

        Intended to be called at the end of the layer's initialization.
        ``bias`` is only copied when the layer actually defines it.
        """
        owner = self.layer
        if hasattr(owner, "bias"):
            self.bias = owner.bias
        self.skip_bias_add = owner.skip_bias_add
        self.return_bias = owner.return_bias
        self.quant_method = owner.quant_method
        self.prefix = owner.prefix

    def apply_impl(self, input_):
        """Compute ``(output, output_bias)``; must be provided by subclasses."""
        raise NotImplementedError

    def apply(self, input_):
        """Run ``apply_impl`` and shape the result like the layer's forward.

        Returns only ``output`` when ``return_bias`` is falsy, otherwise the
        ``(output, output_bias)`` pair.
        """
        output, output_bias = self.apply_impl(input_)
        return (output, output_bias) if self.return_bias else output
class CustomColumnParallelOp(CustomTensorParallelOp):
    """Base class for custom ops attached to column-parallel linear layers.

    Adds the ``gather_output`` attribute on top of what the parent mirrors
    from the layer.
    """

    def __init__(self, layer):
        super().__init__(layer)
        # Copied from the layer by update_attrs().
        self.gather_output = None

    def update_attrs(self):
        """Refresh the mirrored attributes, including ``gather_output``."""
        super().update_attrs()
        owner = self.layer
        self.gather_output = owner.gather_output
class CustomRowParallelOp(CustomTensorParallelOp):
    """Base class for custom ops attached to row-parallel linear layers.

    Mirrors ``input_is_parallel``, ``reduce_results`` and
    ``input_size_per_partition`` from the layer, and extends ``apply`` with an
    optional prefetch call issued after the computation.
    """

    def __init__(self, layer):
        super().__init__(layer)
        # Copied from the layer by update_attrs().
        self.reduce_results = None
        self.input_is_parallel = None
        self.input_size_per_partition = None

    def update_attrs(self):
        """Refresh the mirrored attributes, including the row-specific ones."""
        super().update_attrs()
        owner = self.layer
        self.input_is_parallel = owner.input_is_parallel
        self.reduce_results = owner.reduce_results
        self.input_size_per_partition = owner.input_size_per_partition

    def apply(self, input_):
        """Run ``apply_impl`` and return its result in the layer's format.

        When ``dense_optim_enable()`` is true,
        ``torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj`` is invoked with the
        output and this op's prefix before returning.
        """
        output, output_bias = self.apply_impl(input_)

        if dense_optim_enable():
            torch.ops.vllm.maybe_prefetch_mlp_gate_up_proj(output, self.prefix)

        return (output, output_bias) if self.return_bias else output
class MLPColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op that communicates over the MLP TP group."""

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        """Use the dedicated MLP tensor-parallel group."""
        return get_mlp_tp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """All-gather the input along dim 0, then run the quantized matmul.

        Returns:
            ``(output, output_bias)``; ``output_bias`` is the layer bias when
            ``skip_bias_add`` is set (in which case it is not fused into the
            matmul) and ``None`` otherwise.
        """
        assert self.quant_method is not None

        # The bias is either fused into the matmul or handed back, never both.
        if self.skip_bias_add:
            fused_bias, returned_bias = None, self.bias
        else:
            fused_bias, returned_bias = self.bias, None

        gathered_input = self.comm_group.all_gather(input_, 0)
        output = self.quant_method.apply(self.layer, gathered_input,
                                         fused_bias)
        return output, returned_bias
class SequenceMergedColumnParallelOp(CustomColumnParallelOp):
    """Column-parallel op used for ``gate_up_proj`` when ``enable_sp()`` is on."""

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Forward pass of the merged column-parallel projection.

        The input first goes through
        ``torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)``,
        then through the quantized matmul; the result is all-gathered over
        ``comm_group`` when ``gather_output`` is set.

        Returns:
            ``(output, output_bias)``; ``output_bias`` is the layer bias only
            when ``skip_bias_add`` is set.
        """
        assert self.quant_method is not None

        # The bias is either fused into the matmul or handed back, never both.
        if self.skip_bias_add:
            fused_bias, returned_bias = None, self.bias
        else:
            fused_bias, returned_bias = self.bias, None

        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
        output = self.quant_method.apply(self.layer, input_, fused_bias)

        if self.gather_output:
            # Collect the partitions from every rank of the group.
            output = self.comm_group.all_gather(output)

        return output, returned_bias
class SequenceQKVParallelOp(CustomColumnParallelOp):
    """Column-parallel op selected for non-``gate_up_proj`` layers when ``enable_sp()`` is on."""

    def __init__(self, layer, prefix):
        super().__init__(layer)
        self.prefix = prefix

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Forward pass of the column-parallel projection.

        The third dot-separated component of ``prefix`` is compared with
        ``'0'`` and the result is passed as the second argument to
        ``torch.ops.vllm.maybe_all_gather_and_maybe_unpad``. The quantized
        matmul follows, and the output is all-gathered over ``comm_group``
        when ``gather_output`` is set.

        Returns:
            ``(output, output_bias)``; ``output_bias`` is the layer bias only
            when ``skip_bias_add`` is set.
        """
        assert self.quant_method is not None

        # The bias is either fused into the matmul or handed back, never both.
        if self.skip_bias_add:
            fused_bias, returned_bias = None, self.bias
        else:
            fused_bias, returned_bias = self.bias, None

        # NOTE(review): presumably prefix looks like "model.layers.<idx>...."
        # so index 2 is the layer number - confirm against the callers.
        layer_num = self.prefix.split('.')[2]
        not_first_layer = layer_num != '0'
        input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
            input_, not_first_layer)

        output = self.quant_method.apply(self.layer, input_, fused_bias)
        if self.gather_output:
            # Collect the partitions from every rank of the group.
            output = self.comm_group.all_gather(output)

        return output, returned_bias
class MLPRowParallelOp(CustomRowParallelOp):
    """Row-parallel op that communicates over the MLP TP group."""

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        """Use the dedicated MLP tensor-parallel group."""
        return get_mlp_tp_group()

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Quantized matmul followed by a reduce-scatter along dim 0.

        Returns:
            ``(output, output_bias)``; ``output_bias`` is the bias only when
            ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            local_input = input_
        else:
            # Keep only this rank's slice of the last dimension.
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            local_input = shards[self.tp_rank].contiguous()

        assert self.quant_method is not None

        # Fuse the bias on rank 0 only, so that it is not added once per rank
        # after the reduction.
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.layer.bias

        partial_output = self.quant_method.apply(self.layer,
                                                 local_input,
                                                 bias=fused_bias)
        output = self.comm_group.reduce_scatter(partial_output, 0)

        returned_bias = self.bias if self.skip_bias_add else None
        return output, returned_bias
class OProjRowParallelOp(CustomRowParallelOp):
    """Row-parallel op for ``o_proj`` that communicates over the OTP group."""

    def __init__(self, layer):
        super().__init__(layer)

    @property
    def comm_group(self):
        """Use the dedicated o_proj tensor-parallel group."""
        return get_otp_group()

    def apply_impl(
        self,
        input_: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """All-to-all exchange, quantized matmul, then reduce-scatter.

        Returns:
            ``(output, output_bias)``; ``output`` is viewed as
            ``[input_.shape[0], layer.output_size]`` and ``output_bias`` is
            the bias only when ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            local_input = input_
        else:
            # Keep only this rank's slice of the last dimension.
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            local_input = shards[self.tp_rank].contiguous()

        # Sizes used to lay out the all-to-all buffers.
        local_batch_size = local_input.size(0)
        chunk_size = self.input_size_per_partition
        total_batch_size = local_batch_size * self.tp_size

        # View the input as (rows, tp_size, chunk), move the tp_size axis to
        # the front and flatten, giving one contiguous segment per peer.
        per_peer = local_input.reshape(-1, self.tp_size, chunk_size)
        send_buf = per_peer.transpose(0, 1).contiguous().view(-1)

        recv_buf = torch.empty(total_batch_size * chunk_size,
                               dtype=local_input.dtype,
                               device=local_input.device)

        dist.all_to_all_single(recv_buf,
                               send_buf,
                               group=self.comm_group.device_group)
        local_input = recv_buf.view(total_batch_size, chunk_size)

        # Fuse the bias on rank 0 only, so that it is not added once per rank
        # after the reduction.
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.bias

        assert self.quant_method is not None
        partial_output = self.quant_method.apply(self.layer,
                                                 local_input,
                                                 bias=fused_bias)

        # Combine the partial results across the group.
        output = self.comm_group.reduce_scatter(partial_output, dim=0)
        output = output.view(input_.shape[0], self.layer.output_size)

        returned_bias = self.bias if self.skip_bias_add else None
        return output, returned_bias

    def update_attrs(self):
        """Refresh the mirrored attributes from the layer."""
        super().update_attrs()
        owner = self.layer
        self.input_is_parallel = owner.input_is_parallel
        self.input_size_per_partition = owner.input_size_per_partition
class MatmulAllreduceRowParallelOp(CustomRowParallelOp):
    """Row-parallel op that uses ``torch_npu.npu_mm_all_reduce_base``.

    The HCCL communicator name is looked up once and cached on the class.
    """

    # Cached communicator name, shared by every instance of this class.
    _HCOMM_INFO = None

    def __init__(self, layer):
        super().__init__(layer)
        self.hcomm_info = self.get_hcomm_info(self.comm_group.device_group)

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Calculate the forward output, fusing communication and computation.

        When ``reduce_results`` is set and the group has more than one rank,
        the matmul and all-reduce are issued as a single
        ``npu_mm_all_reduce_base`` call on the transposed weight; otherwise
        the regular quantized matmul is used.

        Returns:
            ``(output, output_bias)``; ``output_bias`` is the bias only when
            ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            local_input = input_
        else:
            # Keep only this rank's slice of the last dimension.
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            local_input = shards[self.tp_rank].contiguous()

        # Fuse the bias on rank 0 only, so that it is not added once per rank.
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.bias

        if self.reduce_results and self.tp_size > 1:
            output = torch_npu.npu_mm_all_reduce_base(local_input,
                                                      self.weight_t,
                                                      self.hcomm_info,
                                                      bias=fused_bias)
        else:
            assert self.quant_method is not None
            output = self.quant_method.apply(self.layer,
                                             local_input,
                                             bias=fused_bias)

        returned_bias = self.bias if self.skip_bias_add else None
        return output, returned_bias

    @classmethod
    def get_hcomm_info(cls, group: ProcessGroup) -> str:
        """Return the HCCL communicator name for ``group``, cached on the class."""
        if cls._HCOMM_INFO is None:
            rank = torch.distributed.get_rank(group)
            if torch.__version__ > "2.0":
                global_rank = torch.distributed.get_global_rank(group, rank)
                backend = group._get_backend(torch.device("npu"))
                cls._HCOMM_INFO = backend.get_hccl_comm_name(global_rank)
            else:
                cls._HCOMM_INFO = group.get_hccl_comm_name(rank)
        return cls._HCOMM_INFO

    def update_attrs(self):
        """Refresh the mirrored attributes and store the transposed weight."""
        super().update_attrs()
        self.weight_t = self.layer.weight.t()
class SequenceRowParallelOp(CustomRowParallelOp):
    """Row-parallel op selected when ``enable_sp()`` is on."""

    def __init__(self, layer, prefix):
        super().__init__(layer)
        self.prefix = prefix

    def apply_impl(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Forward pass of the row-parallel projection.

        Runs the quantized matmul on this rank's input slice. When the group
        has more than one rank and ``reduce_results`` is set, the partial
        output is then passed through ``torch.ops.vllm.maybe_pad_and_reduce``.

        Returns:
            ``(output, output_bias)``; ``output_bias`` is the bias only when
            ``skip_bias_add`` is set.
        """
        if self.input_is_parallel:
            local_input = input_
        else:
            # Keep only this rank's slice of the last dimension.
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            local_input = shards[self.tp_rank].contiguous()

        assert self.quant_method is not None

        # Fuse the bias on rank 0 only, so that it is not added once per rank
        # after the reduction.
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.bias

        needs_reduce = self.tp_size != 1 and self.reduce_results
        output = self.quant_method.apply(self.layer,
                                         local_input,
                                         bias=fused_bias)
        if needs_reduce:
            output = torch.ops.vllm.maybe_pad_and_reduce(output)

        returned_bias = self.bias if self.skip_bias_add else None
        return output, returned_bias

    def update_attrs(self):
        """Refresh the mirrored attributes from the layer."""
        super().update_attrs()
        owner = self.layer
        self.input_is_parallel = owner.input_is_parallel
        self.reduce_results = owner.reduce_results
def get_column_parallel_op(
    disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPColumnParallelOp, SequenceMergedColumnParallelOp,
                          SequenceQKVParallelOp]], int, int]:
    """Pick the custom op (if any) for a column-parallel linear layer.

    Args:
        disable_tp: When truthy, no custom op is used and the layer is
            treated as rank 0 of a group of size 1.
        prefix: Layer name; ``"gate_up_proj"`` in it selects the MLP variants.
        layer: The linear layer the op is attached to.

    Returns:
        ``(custom_op, tp_rank, tp_size)``. ``custom_op`` is ``None`` when no
        switch applies, in which case rank and size come from the TP group.
    """
    if disable_tp:
        return None, 0, 1

    selected: Optional[Union[
        MLPColumnParallelOp,
        SequenceMergedColumnParallelOp,
        SequenceQKVParallelOp,
    ]] = None
    # Precedence: MLP TP for gate_up_proj, then the sequence variants.
    if "gate_up_proj" in prefix and mlp_tp_enable():
        selected = MLPColumnParallelOp(layer)
    elif "gate_up_proj" in prefix and enable_sp():
        selected = SequenceMergedColumnParallelOp(layer)
    elif enable_sp():
        selected = SequenceQKVParallelOp(layer, prefix)

    if selected is None:
        return None, get_tp_group().rank_in_group, get_tp_group().world_size
    return selected, selected.tp_rank, selected.tp_size
def get_row_parallel_op(
    disable_tp, prefix, layer
) -> Tuple[Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                          MatmulAllreduceRowParallelOp,
                          SequenceRowParallelOp]], int, int]:
    """Pick the custom op (if any) for a row-parallel linear layer.

    Args:
        disable_tp: When truthy, no custom op is used and the layer is
            treated as rank 0 of a group of size 1.
        prefix: Layer name; ``"down_proj"`` and ``"o_proj"`` select the
            dedicated variants when their switches are on.
        layer: The linear layer the op is attached to.

    Returns:
        ``(custom_op, tp_rank, tp_size)``. ``custom_op`` is ``None`` when no
        switch applies, in which case rank and size come from the TP group.
    """
    if disable_tp:
        return None, 0, 1

    selected: Optional[Union[MLPRowParallelOp, OProjRowParallelOp,
                             MatmulAllreduceRowParallelOp,
                             SequenceRowParallelOp]] = None
    # Precedence: MLP TP, o_proj TP, fused matmul+all-reduce, sequence variant.
    if "down_proj" in prefix and mlp_tp_enable():
        selected = MLPRowParallelOp(layer)
    elif "o_proj" in prefix and oproj_tp_enable():
        selected = OProjRowParallelOp(layer)
    elif matmul_allreduce_enable():
        selected = MatmulAllreduceRowParallelOp(layer)
    elif enable_sp():
        selected = SequenceRowParallelOp(layer, prefix)

    if selected is None:
        return None, get_tp_group().rank_in_group, get_tp_group().world_size
    return selected, selected.tp_rank, selected.tp_size

View File

View File

@@ -1,5 +1,7 @@
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
import torch.distributed
import torch.distributed as dist
@@ -60,3 +62,52 @@ def async_all_to_all(input_,
group=group,
async_op=True)
return input_, a2a_out, handle
def _gather_along_first_dim(input_, group, output_split_sizes=None):
    """All-gather ``input_`` over ``group`` and concatenate along dim 0.

    Args:
        input_ (torch.Tensor):
            The local tensor to be gathered.
        group:
            Process group the gather runs on.
        output_split_sizes (List[int], optional):
            Sizes of the per-rank pieces along the first dimension of the
            output. If None, every rank is assumed to contribute a tensor
            of the same size. Default: None.

    Returns:
        torch.Tensor: The gathered tensor, or ``input_`` itself when the
        group has a single rank.
    """
    world_size = torch.distributed.get_world_size(group)
    # Nothing to gather with a single rank.
    if world_size == 1:
        return input_

    out_shape = list(input_.size())

    if output_split_sizes is not None:
        # Uneven contributions: gather into views of one output buffer.
        out_shape[0] = sum(output_split_sizes)
        output = torch.empty(out_shape,
                             dtype=input_.dtype,
                             device=torch.npu.current_device())
        pieces = list(torch.split(output, output_split_sizes, dim=0))
        torch.distributed.all_gather(pieces, input_, group=group)
        return output

    # Equal contributions: a single flat gather into the output buffer.
    out_shape[0] = out_shape[0] * world_size
    output = torch.empty(out_shape,
                         dtype=input_.dtype,
                         device=torch.npu.current_device())
    torch.distributed.all_gather_into_tensor(output,
                                             input_.contiguous(),
                                             group=group)
    return output
def gather_from_sequence_parallel_region(
    input_,
    group,
    output_split_sizes=None,
):
    """All-gather ``input_`` along the first dimension over ``group``.

    Thin wrapper that forwards its arguments unchanged to
    ``_gather_along_first_dim`` and returns its result.
    """
    gathered = _gather_along_first_dim(input_, group, output_split_sizes)
    return gathered

View File

@@ -0,0 +1,459 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_dp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import vllm_version_is
class FusedMoEPrepareAndFinalize(ABC):
    """
    Abstract interface for the tensor handling done around an MoE layer.

    ``prepare`` runs before the expert computation and ``finalize`` after it.
    Concrete subclasses implement one communication strategy each (MC2,
    All2All, AllGather, Naive Multicast) and take care of padding, slicing,
    gathering and reduction as that strategy requires.

    Attributes:
        moe_config (FusedMoEConfig): MoE configuration; subclasses read
            parallel sizes, ranks and groups from it.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        self.moe_config = moe_config

    @abstractmethod
    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Transform the inputs before the MoE computation.

        Depending on the strategy this may pad the tensors, slice them per
        tensor-parallel rank, gather them across data-parallel ranks, or
        recompute the router logits.

        Args:
            hidden_states (torch.Tensor): Input features, shape [num_tokens, hidden_size]
            router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
            enable_shared_expert_dp (bool): Skip DP communication for shared experts
            rm_router_logits (bool): Ignore the given router_logits and recompute them via gate
            replace_allreduce (bool): Bypass the default all-reduce behavior
            gate (nn.Module, optional): Gate network used to recompute router_logits

        Returns:
            Tuple of:
                - processed hidden_states
                - processed router_logits
                - optional communication mask (e.g. mc2_mask), or None
        """
        raise NotImplementedError("Prepare not implemented.")

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Transform the MoE output back to the caller's layout.

        Depending on the strategy this may gather sliced tensors, reduce or
        scatter across ranks, drop padding rows, or apply an all-reduce.

        Args:
            hidden_states (torch.Tensor): MoE layer output, possibly padded or sliced
            reduce_results (bool): Whether to apply all-reduce across TP/EP groups

        Returns:
            torch.Tensor: Final output with shape [original_num_tokens, hidden_size]
        """
        raise NotImplementedError("Finalize function not implemented.")
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy for MC2.

    Reads `mc2_mask` and `padded_num_tokens` from the forward context, pads
    the inputs to that length and hands each TP rank its own slice of tokens.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        """Store the tensor-parallel world size and rank on the instance."""
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
          1. Read `mc2_mask` from the forward context; with TP > 1 keep only
             this rank's slice of it.
          2. Unless `replace_allreduce` is set, record the token count and pad
             `hidden_states` / `router_logits` up to `padded_num_tokens`.
          3. With TP > 1, split both tensors along the token dimension, keep
             this rank's slice and remember the splits for `finalize`.

        Steps 2 (padding) and 3 are skipped when `enable_shared_expert_dp`
        is set.

        Returns:
            Tuple of (hidden_states, router_logits, mc2_mask), possibly sliced/padded.
        """
        self.replace_allreduce = replace_allreduce
        self.enable_shared_expert_dp = enable_shared_expert_dp

        forward_context = get_forward_context()
        mc2_mask = forward_context.mc2_mask
        if self.tp_size > 1:
            # The mask is sliced per TP rank as well.
            mask_slices = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
            mc2_mask = mask_slices[self.tp_rank]

        if self.replace_allreduce:
            return hidden_states, router_logits, mc2_mask

        self.num_tokens, _ = hidden_states.shape
        pad_size = forward_context.padded_num_tokens - self.num_tokens
        shared_dp = self.enable_shared_expert_dp

        if pad_size > 0 and not shared_dp:
            # Append zero rows at the end of the token dimension.
            row_padding = (0, 0, 0, pad_size)
            hidden_states = nn.functional.pad(hidden_states, row_padding)
            router_logits = nn.functional.pad(router_logits, row_padding)

        if self.tp_size > 1 and not shared_dp:
            split_hidden_states = torch.tensor_split(hidden_states,
                                                     self.tp_size,
                                                     dim=0)
            split_router_logits = torch.tensor_split(router_logits,
                                                     self.tp_size,
                                                     dim=0)
            hidden_states = split_hidden_states[self.tp_rank]
            router_logits = split_router_logits[self.tp_rank]
            # finalize() gathers into these slices.
            self.split_hidden_states = split_hidden_states

        return hidden_states, router_logits, mc2_mask

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
          1. With TP > 1, all-gather the slices from all TP ranks and
             concatenate them along the token dimension.
          2. Drop the rows beyond the original token count.

        Both steps are skipped when `enable_shared_expert_dp` or
        `replace_allreduce` is set; the input is then returned as is.
        """
        if self.enable_shared_expert_dp or self.replace_allreduce:
            return hidden_states

        if self.tp_size > 1:
            dist.all_gather(list(self.split_hidden_states), hidden_states,
                            self.moe_config.tp_group.device_group)
            hidden_states = torch.cat(self.split_hidden_states, dim=0)

            # TODO: It is a quick bugfix for the memory explosion issue in eager mode.
            # If the cache is not cleared after `self.split_hidden_states` is created,
            # it can lead to the memory explosion in eager mode.
            del self.split_hidden_states

        if self.num_tokens < hidden_states.shape[0]:
            hidden_states = hidden_states[:self.num_tokens]

        return hidden_states
class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy using All-to-All style slicing.

    Similar to the MC2 strategy but without `mc2_mask`: the inputs are padded
    so that there are at least `tp_size` rows and are then split per TP rank.
    Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self._restore_tp_across_dp()

    def _restore_tp_across_dp(self):
        """Store the tensor-parallel world size and rank on the instance."""
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps:
          1. If there are fewer tokens than TP ranks, pad hidden_states and
             router_logits with zero rows up to `tp_size` rows.
          2. If TP > 1, split along the token dim and keep this rank's slice.
          3. Save the splits for the all-gather in finalize.

        Everything is skipped if `enable_shared_expert_dp` or
        `replace_allreduce` is True.

        Returns:
            Tuple of (hidden_states, router_logits, None) - no mask is used here.
        """
        self.replace_allreduce = replace_allreduce
        self.enable_shared_expert_dp = enable_shared_expert_dp

        if self.replace_allreduce or self.enable_shared_expert_dp:
            return hidden_states, router_logits, None

        self.num_tokens, _ = hidden_states.shape
        pad_size = self.tp_size - self.num_tokens

        if pad_size > 0:
            # Append zero rows at the end of the token dimension.
            row_padding = (0, 0, 0, pad_size)
            hidden_states = nn.functional.pad(hidden_states, row_padding)
            router_logits = nn.functional.pad(router_logits, row_padding)

        if self.tp_size > 1:
            split_hidden_states = torch.tensor_split(hidden_states,
                                                     self.tp_size,
                                                     dim=0)
            split_router_logits = torch.tensor_split(router_logits,
                                                     self.tp_size,
                                                     dim=0)
            # finalize() gathers into these slices.
            self.split_hidden_states = split_hidden_states

            hidden_states = split_hidden_states[self.tp_rank]
            router_logits = split_router_logits[self.tp_rank]

        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
          1. If TP > 1, all-gather the slices and concatenate them along the
             token dimension.
          2. Drop the rows beyond the original token count.

        Both steps are skipped if `enable_shared_expert_dp` or
        `replace_allreduce` is True; the input is then returned as is.
        """
        if self.enable_shared_expert_dp or self.replace_allreduce:
            return hidden_states

        if self.tp_size > 1:
            dist.all_gather(list(self.split_hidden_states), hidden_states,
                            self.moe_config.tp_group.device_group)
            hidden_states = torch.cat(self.split_hidden_states, dim=0)

            # TODO: It is a quick bugfix for the memory explosion issue in eager mode.
            # If the cache is not cleared after `self.split_hidden_states` is created,
            # it can lead to the memory explosion in eager mode.
            del self.split_hidden_states

        if self.num_tokens < hidden_states.shape[0]:
            hidden_states = hidden_states[:self.num_tokens]

        return hidden_states
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy using All-Gather + Reduce-Scatter.

    With DP > 1 the inputs are gathered across the DP group before the MoE
    computation and the outputs are reduce-scattered afterwards. Padding is
    aligned to `max_tokens_across_dp` from the forward context.
    """

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Preparation steps (only when DP > 1):
          1. Read the max token count across the DP group from the forward context.
          2. Pad the local tensors with zero rows up to that count.
          3. All-gather hidden_states across the DP group.
          4. Either all-gather router_logits as well, or, when
             `rm_router_logits=True`, recompute them with `gate` on the
             gathered hidden_states.

        Returns:
            Tuple of (hidden_states, router_logits, None)
        """
        self.enable_shared_expert_dp = enable_shared_expert_dp

        if self.moe_config.dp_size <= 1:
            return hidden_states, router_logits, None

        forward_context = get_forward_context()
        max_tokens_across_dp = forward_context.max_tokens_across_dp

        self.num_tokens = hidden_states.shape[0]
        pad_size = max_tokens_across_dp - self.num_tokens
        if pad_size > 0:
            row_padding = (0, 0, 0, pad_size)
            hidden_states = nn.functional.pad(hidden_states, row_padding)
            if not rm_router_logits:
                # Logits that will be recomputed need no padding.
                router_logits = nn.functional.pad(router_logits, row_padding)

        dp_group = self.moe_config.dp_group
        hidden_states = dp_group.all_gather(hidden_states, 0)
        if rm_router_logits:
            router_logits, _ = gate(hidden_states)
        else:
            router_logits = self.moe_config.dp_group.all_gather(
                router_logits, 0)

        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Finalization steps:
          1. If DP > 1 and `enable_shared_expert_dp` is off, reduce-scatter the
             output across the DP group and keep the first `num_tokens` rows.
          2. If `reduce_results=True` and TP or EP > 1, apply
             tensor_model_parallel_all_reduce.

        Returns:
            The finalized hidden_states tensor.
        """
        config = self.moe_config

        if config.dp_size > 1 and not self.enable_shared_expert_dp:
            hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
            hidden_states = hidden_states[:self.num_tokens]

        if reduce_results and (config.tp_size > 1 or config.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states
class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize):
    """
    MoE communication strategy built on naive multicast: every DP rank
    broadcasts its own slice of tokens to all other ranks.

    Will be used in prefill when using allgather in decode. Slice boundaries
    come from the cumulative per-rank token counts (`cu_tokens_across_dp_cpu`).
    """

    def _naive_multicast(self, x: torch.Tensor,
                         cu_tokens_across_dp_cpu: torch.Tensor):
        """
        Assemble the global tensor by broadcasting each rank's slice.

        1. Allocate a buffer with one row per token across all DP ranks.
        2. Copy the local tensor into this rank's region of the buffer.
        3. For every DP rank in turn, broadcast that rank's region.

        Args:
            x (torch.Tensor): Local tensor [local_tokens, hidden_size]
            cu_tokens_across_dp_cpu (torch.Tensor): Cumulative token counts
                per DP rank

        Returns:
            torch.Tensor: Global tensor [total_tokens, hidden_size]
        """
        assert len(x.shape) == 2, "Input must be 2D [tokens, features]"

        total_tokens = cu_tokens_across_dp_cpu[-1]
        buffer = torch.empty((total_tokens, x.size(1)),
                             device=x.device,
                             dtype=x.dtype)

        # Place the local rows at this rank's offset.
        own_rank = self.moe_config.dp_rank
        if own_rank == 0:
            lo = 0
        else:
            lo = cu_tokens_across_dp_cpu[own_rank - 1]
        hi = cu_tokens_across_dp_cpu[own_rank]
        buffer[lo:hi, :].copy_(x)

        # Each rank acts as the broadcast source for its own region.
        for src_rank in range(self.moe_config.dp_size):
            if src_rank == 0:
                lo = 0
            else:
                lo = cu_tokens_across_dp_cpu[src_rank - 1]
            hi = cu_tokens_across_dp_cpu[src_rank]
            get_dp_group().broadcast(buffer[lo:hi, :], src_rank)

        return buffer

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Build global inputs for the MoE layer.

        1. Read the cumulative token boundaries from the forward context.
        2. Multicast hidden_states (and router_logits) into global tensors.
        3. If `rm_router_logits=True`, recompute router_logits by calling
           `gate` on the global hidden_states instead of multicasting them.

        With DP size <= 1 the inputs are returned untouched.

        Returns:
            Tuple of (global_hidden_states, global_router_logits, None)
        """
        self.enable_shared_expert_dp = enable_shared_expert_dp

        if self.moe_config.dp_size <= 1:
            return hidden_states, router_logits, None

        # The DP metadata exposes the boundaries differently per vLLM version.
        is_legacy_vllm = vllm_version_is("0.10.2")
        dp_metadata = get_forward_context().dp_metadata
        if is_legacy_vllm:
            self.cu_tokens_across_dp_cpu = dp_metadata.cu_tokens_across_dp_cpu
        else:
            self.cu_tokens_across_dp_cpu = dp_metadata.cu_tokens_across_sp(1)

        hidden_states = self._naive_multicast(hidden_states,
                                              self.cu_tokens_across_dp_cpu)
        if rm_router_logits:
            router_logits, _ = gate(hidden_states)
        else:
            router_logits = self._naive_multicast(
                router_logits, self.cu_tokens_across_dp_cpu)

        return hidden_states, router_logits, None

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """
        Post-process the MoE output for the naive-multicast strategy.

        1. When DP size > 1 and shared-expert DP is disabled:
           - all-reduce across the DP group
           - keep only this rank's rows, located via cu_tokens_across_dp_cpu
        2. When `reduce_results=True` and TP or EP size > 1, apply
           tensor_model_parallel_all_reduce.

        Returns:
            Tensor with shape [local_num_tokens, hidden_size]
        """
        moe_cfg = self.moe_config

        if moe_cfg.dp_size > 1 and not self.enable_shared_expert_dp:
            own_rank = moe_cfg.dp_rank
            if own_rank == 0:
                lo = 0
            else:
                lo = self.cu_tokens_across_dp_cpu[own_rank - 1]
            hi = self.cu_tokens_across_dp_cpu[own_rank]
            # Sum across DP, then cut out the local token range.
            summed = get_dp_group().all_reduce(hidden_states)
            hidden_states = summed[lo:hi, :]

        if reduce_results and (moe_cfg.tp_size > 1 or moe_cfg.ep_size > 1):
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)

        return hidden_states

View File

@@ -0,0 +1,273 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
FusedMoEPrepareAndFinalizeWithAll2All,
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2,
TokenDispatcherWithMoge)
# Module-level registry of MoE communication method instances keyed by
# MoECommType; filled by setup_moe_comm_method() and read by
# get_moe_comm_method().
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
def get_moe_comm_method(
        moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
    """Look up the registered communication method for `moe_comm_type`.

    Returns None when nothing is registered under that key.
    """
    registered = _MoECommMethods.get(moe_comm_type)
    return registered
def setup_moe_comm_method(moe_config):
    """Instantiate every MoE communication method and register it by type.

    Each implementation is built from the same `moe_config`; an existing
    entry for a given MoECommType is overwritten.
    """
    implementations = (
        (MoECommType.ALLTOALL, AlltoAllCommImpl),
        (MoECommType.ALLGATHER, AllGatherCommImpl),
        (MoECommType.MC2, MC2CommImpl),
        (MoECommType.NAIVE_MULTICAST, NaiveMulticastCommImpl),
    )
    for comm_type, impl_cls in implementations:
        _MoECommMethods[comm_type] = impl_cls(moe_config)
class MoECommMethod(ABC):
    """Abstract base for MoE communication methods.

    A concrete subclass supplies a token dispatcher and a prepare/finalize
    helper; this base class drives them in `prepare`, `fused_experts` and
    `finalize`.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        vllm_config = get_current_vllm_config()
        self.model_type = vllm_config.model_config.hf_config.model_type
        self.moe_config = moe_config
        # Set by prepare(); forwarded to the token dispatcher afterwards.
        self.mc2_mask = None

        # The factories below may read model_type / moe_config, so they
        # run only after those attributes are in place.
        self.token_dispatcher = self._get_token_dispatcher()
        self.fused_moe_prepare_finalize = (
            self._get_fused_moe_prepare_finalize())

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the prepare/finalize helper and remember its mask.

        Returns:
            Tuple of (hidden_states, router_logits) as produced by the helper;
            the third value the helper returns is stored in `self.mc2_mask`.
        """
        prepared = self.fused_moe_prepare_finalize.prepare(
            hidden_states, router_logits, enable_shared_expert_dp,
            rm_router_logits, replace_allreduce, gate)
        hidden_states, router_logits, self.mc2_mask = prepared
        return hidden_states, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        """Delegate finalization to the prepare/finalize helper."""
        return self.fused_moe_prepare_finalize.finalize(
            hidden_states, reduce_results)

    def fused_experts(
            self,
            hidden_states: torch.Tensor,
            w1: torch.Tensor,
            w2: torch.Tensor,
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            row_idx: torch.Tensor,
            activation: str = "silu",
            apply_router_weight_on_input: bool = False,
            use_int8_w8a8: bool = False,
            use_int4_w4a8: bool = False,
            global_num_experts: Optional[int] = None,
            expert_map: Optional[torch.Tensor] = None,
            w1_scale: Optional[torch.Tensor] = None,
            w2_scale: Optional[torch.Tensor] = None,
            w1_scale_bias: torch.Tensor = None,
            w2_scale_bias: torch.Tensor = None,
            # For TorchAir graph
            is_torchair: bool = False,
            # For Cube/Vector parallel
            shared_experts: Optional[Any] = None,
            quantized_x_for_share: Optional[Any] = None,
            dynamic_scale_for_share: Optional[Any] = None,
            # For load balance
            log2phy: torch.Tensor = None,
            global_redundant_expert_num: int = 0,
            need_trans: bool = False,
            dynamic_eplb: bool = False):
        """Run dispatch -> expert MLP -> combine for one MoE layer.

        Returns:
            The combined hidden states, or the tuple
            (final_hidden_states, group_list_type, expert_tokens) when
            `dynamic_eplb` is True.
        """
        # Check constraints
        supported_dtypes = [torch.float32, torch.float16, torch.bfloat16]
        assert hidden_states.dtype in supported_dtypes

        moe_comm_method = get_forward_context().moe_comm_method
        assert moe_comm_method is not None, "Missing communication context"

        # Either quantization flag switches dispatch and MLP to the quant path.
        with_quant = use_int8_w8a8 or use_int4_w4a8

        results = self.token_dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=row_idx,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            mc2_mask=self.mc2_mask,
            apply_router_weight_on_input=apply_router_weight_on_input,
            with_quant=with_quant)

        # "dynamic_scale" and "topk_scales" are optional in the dispatch
        # result; the other keys are required.
        permuted_hidden_states = results["hidden_states"]
        expert_tokens = results["group_list"]
        dynamic_scale = results.get("dynamic_scale")
        group_list_type = results["group_list_type"]
        topk_scales = results.get("topk_scales")

        mlp_output = unified_apply_mlp(
            hidden_states=permuted_hidden_states,
            w1=w1,
            w1_scale=w1_scale,
            w2=w2,
            w2_scale=w2_scale,
            group_list=expert_tokens,
            dynamic_scale=dynamic_scale,
            group_list_type=group_list_type,
            w1_scale_bias=w1_scale_bias,
            w2_scale_bias=w2_scale_bias,
            topk_scales=topk_scales,
            with_quant=with_quant,
            fusion=use_int8_w8a8,
            need_trans=need_trans)

        final_hidden_states = self.token_dispatcher.token_combine(
            hidden_states=mlp_output)

        if not dynamic_eplb:
            return final_hidden_states
        return (final_hidden_states, group_list_type, expert_tokens)

    @abstractmethod
    def _get_token_dispatcher(self):
        """Return the token dispatcher used by this communication method."""
        raise NotImplementedError(
            "_get_token_dispatcher function not implemented.")

    @abstractmethod
    def _get_fused_moe_prepare_finalize(self):
        """Return the prepare/finalize helper used by this method."""
        raise NotImplementedError(
            "_get_fused_moe_prepare_finalize function not implemented.")
class AllGatherCommImpl(MoECommMethod):
    """All-gather based MoE communication method.

    Behaves the same as NativeAllGatherCommImpl but relies on NPU-specific
    ops for better performance. It should be compatible with all scenarios
    and is therefore the default MoE communication method.

    Pre-processing uses `torch_npu.npu_moe_init_routing_v2` and
    post-processing uses `torch_npu.npu_moe_token_unpermute` to handle the
    token-to-expert mapping and communication efficiently.

    NOTE(Yizhou): The expected pairings are
    `torch_npu.npu_moe_init_routing_v2` with
    `torch_npu.npu_moe_finalize_routing`, or
    `torch_npu.npu_moe_token_permute` with
    `torch_npu.npu_moe_token_unpermute`. `npu_moe_finalize_routing` leads
    to accuracy issues, so `torch_npu.npu_moe_token_unpermute` is used
    instead. This is a workaround and should be removed after the issue is
    fixed.
    """

    def _get_token_dispatcher(self):
        # PanguProMoE gets its dedicated dispatcher; everything else uses
        # the generic all-gather dispatcher.
        if self.model_type == "PanguProMoE":
            dispatcher_cls = TokenDispatcherWithMoge
        else:
            dispatcher_cls = TokenDispatcherWithAllGather

        cfg = self.moe_config
        return dispatcher_cls(top_k=cfg.experts_per_token,
                              num_experts=cfg.num_experts,
                              num_local_experts=cfg.num_local_experts)

    def _get_fused_moe_prepare_finalize(self):
        helper = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
        return helper
class MC2CommImpl(MoECommMethod):
    """MC2-based MoE communication method.

    Applicable when:
    1. `enable_expert_parallel=True` (`enable_expert_parallel=False` is not
       supported).
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are
       available.

    MC2 is optimized for Communication and Computation parallelism on
    Ascend devices.
    """

    def _get_token_dispatcher(self):
        dispatcher = TokenDispatcherWithMC2()
        return dispatcher

    def _get_fused_moe_prepare_finalize(self):
        helper = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
        return helper
class AlltoAllCommImpl(MoECommMethod):
    """All-to-all based MoE communication method.

    Applicable when:
    1. `enable_expert_parallel=True`.
    2. `npu_grouped_matmul` is available.

    Tokens are exchanged between data parallel ranks with all-to-all
    communication before and after the MLP computation. It should perform
    better than AllGatherCommImpl when DP size > 1.
    """

    def _get_token_dispatcher(self):
        cfg = self.moe_config
        dispatcher = TokenDispatcherWithAll2AllV(
            top_k=cfg.experts_per_token,
            num_experts=cfg.num_experts,
            num_local_experts=cfg.num_local_experts)
        return dispatcher

    def _get_fused_moe_prepare_finalize(self):
        helper = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
        return helper
class NaiveMulticastCommImpl(MoECommMethod):
    """Naive-multicast MoE communication method.

    Token dispatch and combine are handled by `TokenDispatcherWithAllGather`
    (the same dispatcher class AllGatherCommImpl uses for models other than
    PanguProMoE). The prepare/finalize steps are handled by
    `FusedMoEPrepareAndFinalizeWithNaiveMulticast`.

    Registered under `MoECommType.NAIVE_MULTICAST` by
    `setup_moe_comm_method`.
    """

    def _get_token_dispatcher(self):
        return TokenDispatcherWithAllGather(
            top_k=self.moe_config.experts_per_token,
            num_experts=self.moe_config.num_experts,
            num_local_experts=self.moe_config.num_local_experts)

    def _get_fused_moe_prepare_finalize(self):
        return FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)

View File

@@ -18,22 +18,52 @@ from typing import Optional
import torch
import torch_npu
from torch.nn.functional import pad
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import dispose_tensor, is_310p
def cumsum_group_list(group_list: torch.Tensor,
                      group_list_type: int,
                      active_num: int = 0,
                      expert_num: int = 0) -> torch.Tensor:
    """Convert a group list of any supported type to cumulative form.

    Args:
        group_list: Per-expert token information; layout depends on
            `group_list_type`.
        group_list_type: One of
            0 - returned unchanged (presumably already cumulative),
            1 - per-expert counts; the cumulative sum along dim 0 is returned,
            2 - 2-D rows whose column 0 is used as an expert index and
                column 1 as a token count.
        active_num: Only used for type 2. Fill value for every slot the
            rows do not cover (from the last listed expert index onwards).
        expert_num: Only used for type 2. Length of the returned tensor.

    Returns:
        A 1-D tensor in cumulative form (for type 0, `group_list` itself).

    Raises:
        ValueError: If `group_list_type` is not 0, 1 or 2.
    """
    if group_list_type not in (0, 1, 2):
        raise ValueError(
            f"group_list_type should be in [0, 1, 2], but received {group_list_type}"
        )

    if group_list_type == 0:
        return group_list
    if group_list_type == 1:
        return group_list.cumsum(dim=0)

    # Type 2: prepend a zero so that consecutive pairs of expert indices
    # delimit the range of experts sharing one cumulative value.
    # Read the boundaries to host once instead of comparing / slicing with
    # 0-dim tensors on every iteration.
    expert_bounds = pad(group_list[:, 0], (1, 0)).tolist()
    cumulative_tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))

    result = torch.full(size=(expert_num, ),
                        fill_value=active_num,
                        dtype=group_list.dtype,
                        device=group_list.device)
    for i, (start, end) in enumerate(
            zip(expert_bounds[:-1], expert_bounds[1:])):
        if end > start:
            result[start:end] = cumulative_tokens[i]
    return result
def quant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
dynamic_scale: torch.Tensor = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
w2_scale_bias: torch.Tensor = None,
fusion: bool = False) -> torch.Tensor:
if dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
@@ -47,33 +77,40 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and is_mc2:
w1_scale = w1_scale.to(torch.float32)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
if fusion:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
x_scale=pertoken_scale)
else:
if w1_scale.dtype != torch.float32:
w1_scale = w1_scale.to(torch.float32)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
@@ -92,29 +129,37 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
[group_list[:1],
torch.diff(group_list, dim=0)])
group_list_type = 1
bias1 = [w1_scale_bias]
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
bias2 = [w2_scale_bias]
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
scale=[w1_scale],
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
if fusion:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
x_scale=pertoken_scale)
else:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
scale=[w1_scale.to(w2_scale.dtype)],
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
@@ -127,17 +172,22 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
return hidden_states
def unquant_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
w1 = w1.transpose(1, 2)
def unquant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
topk_scales: Optional[torch.Tensor] = None,
need_trans: bool = True) -> torch.Tensor:
if need_trans:
w1 = w1.transpose(1, 2)
w2 = w2.transpose(1, 2)
gate_up_out = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
@@ -155,7 +205,6 @@ def unquant_apply_mlp(
if topk_scales is not None:
gate_up_out *= topk_scales
w2 = w2.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
@@ -178,7 +227,9 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
topk_scales: Optional[torch.Tensor] = None,
with_quant: bool = False) -> torch.Tensor:
with_quant: bool = False,
fusion: bool = False,
need_trans: bool = True) -> torch.Tensor:
if with_quant:
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
@@ -189,11 +240,13 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias)
w2_scale_bias=w2_scale_bias,
fusion=fusion)
else:
return unquant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
group_list_type=group_list_type,
topk_scales=topk_scales)
topk_scales=topk_scales,
need_trans=need_trans)

View File

@@ -22,42 +22,17 @@
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Optional
import torch
import torch_npu
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.distributed.tensor_parallel import \
gather_from_sequence_parallel_region
from vllm_ascend.ops.comm_utils import async_all_to_all
from vllm_ascend.ops.moe.comm_utils import (
async_all_to_all, gather_from_sequence_parallel_region)
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
_Dispatchers: Dict[str, Any] = {}
def _register_token_dispatcher(dispatcher: Any):
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
def get_token_dispatcher(name: str):
return _Dispatchers.get(name)
def setup_token_dispatchers(ep_size: int, **kwargs):
existing_dispatchers = set(_Dispatchers.keys())
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
elif ep_size >= 16:
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
if "TokenDispatcherWithMC2" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
class MoETokenDispatcher(ABC):
@@ -90,9 +65,9 @@ class MoETokenDispatcher(ABC):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
@@ -158,6 +133,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"expert_token_nums_type": 0,
}
stage1_kwargs = {
@@ -189,9 +165,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
@@ -215,6 +191,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
if self.with_quant:
if shared_experts is not None:
share_up_out, _ = shared_experts.gate_up_proj(
(quantized_x_for_share, dynamic_scale_for_share))
shared_gate_up, shared_dequant_scale = share_up_out[
0], share_up_out[1]
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \
@@ -224,7 +205,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 1
group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": expand_x,
@@ -291,6 +272,16 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
# these values are no longer used, so they need to be set to None for memory release.
self.output = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.topk_ids = None
self.topk_weights = None
self.mc2_mask = None
self.expert_map = None
if self.shared_experts is None:
return hidden_states
else:
@@ -300,6 +291,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
else:
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
self.shared_act = None
self.shared_experts = None
self.swiglu_out_scale = None
return hidden_states, shared_hidden_states
@@ -328,9 +322,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
@@ -338,8 +332,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
device = hidden_states.device
self.expert_map = expert_map
self.topk_weights = topk_weights
self.topk_ids = topk_ids
@@ -353,144 +345,65 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
if expert_map is not None:
# Generate token indices and flatten
token_indices = (torch.arange(
num_tokens, device=device,
dtype=torch.int64).unsqueeze(1).expand(-1,
self.top_k).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = expert_map[experts_flat]
# Filter valid token-expert pairs
self.mask = local_experts_flat != -1
filtered_weights = torch.where(
self.mask, weights_flat,
torch.zeros_like(weights_flat)).to(dtype)
filtered_experts = torch.where(
self.mask, local_experts_flat,
torch.full_like(local_experts_flat,
self.num_experts_local)).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
self.sorted_token_indices = token_indices[sort_indices]
self.sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(self.num_experts_local + 1,
device=device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
ones)
token_counts = token_counts[:self.num_experts_local]
# Rearrange hidden_states
sorted_hidden_states = hidden_states[self.sorted_token_indices]
if self.with_quant:
group_list_type = 1
expert_tokens = token_counts
else:
expert_tokens = torch.cumsum(token_counts,
dim=0,
dtype=torch.int64)
group_list_type = 0
global_num_experts = len(expert_map)
mask = (expert_map[topk_ids] != -1)
self.topk_weights = topk_weights * mask
first_expert_idx = get_ep_group(
).rank_in_group * self.num_experts_local
last_expert_idx = first_expert_idx + self.num_experts_local
else:
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=active_num)
first_expert_idx = 0
last_expert_idx = self.num_experts_local
global_num_experts = self.num_experts_local
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, self.num_experts_local)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
sorted_hidden_states, self.expanded_row_idx, expert_tokens, pertoken_scale = (
torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
active_num=num_tokens * self.top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=1 if self.with_quant else -1,
))
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
"dynamic_scale": pertoken_scale if self.with_quant else None,
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
assert self.original_shape is not None
dtype = hidden_states.dtype
device = hidden_states.device
if self.expert_map is not None:
assert self.mask is not None
assert self.sorted_token_indices is not None
assert self.sorted_weights is not None
final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states,
sorted_indices=self.expanded_row_idx,
probs=self.topk_weights)
if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(self.original_shape)
weighted_down_out = hidden_states * \
self.sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros(*self.original_shape,
device=hidden_states.device,
dtype=hidden_states.dtype)
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# This created multiple NaN and index_add_ will mix them up which harms accuracy
# remove this mask and filter after it being fixed
num_valid_tokens = self.mask.sum()
valid_token_mask = torch.arange(
0, self.sorted_token_indices.shape[0],
device=device).unsqueeze(1) < num_valid_tokens
valid_output = torch.where(
valid_token_mask, weighted_down_out,
torch.zeros_like(weighted_down_out)).to(dtype)
final_hidden_states.index_add_(0, self.sorted_token_indices,
valid_output)
else:
if self.with_quant:
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=self.topk_weights,
expanded_src_to_dst_row=self.expanded_row_idx,
export_for_source_row=self.topk_ids,
)
if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(
self.original_shape)
else:
scales = torch.ones_like(
self.topk_weights
) if self.apply_router_weight_on_input else self.topk_weights
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=scales,
expanded_src_to_dst_row=self.expanded_row_idx,
export_for_source_row=self.topk_ids,
)
# these values are no longer used, so they need to be set to None for memory release.
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.expanded_row_idx = None
return final_hidden_states
# mypy: disable-error-code="override"
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
class TokenDispatcherWithMoge(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.local_ep = 1
self.local_num_experts = self.num_experts // self.local_ep
self.local_num_group = self.top_k // self.local_ep
self.local_num_experts = self.num_experts // self.ep_size
self.local_num_group = self.top_k // self.ep_size
self.bsz = None
def token_dispatch(self,
@@ -501,23 +414,12 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
@@ -551,7 +453,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
unsorted_hidden_states = hidden_states.index_select(
0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
self.bsz, self.top_k // self.local_ep, -1).sum(1)
self.bsz, self.top_k // self.ep_size, -1).sum(1)
return final_hidden_states
@@ -613,9 +515,9 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
@@ -681,9 +583,14 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
output = self._combine_postprocess(permutated_local_input_tokens)
# these values are no longer used, so they need to be set to None for memory release.
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
self.topk_weights = None
self.reversed_local_input_permutation_mapping = None
self.reversed_global_input_permutation_mapping = None
self.global_input_tokens_local_experts_indices = None
return output
@@ -745,6 +652,10 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank,
self.num_global_tokens_per_local_expert.ravel())
else:
# TODO: This full synchronization can be a performance bottleneck.
# A more granular sync (e.g., blocking D2H copies) should be investigated.
torch.npu.synchronize()
return num_tokens_per_local_expert

View File

@@ -0,0 +1,201 @@
import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm.utils import direct_register_custom_op
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
def _maybe_chunk_residual_impl(x: torch.Tensor,
                               residual: torch.Tensor) -> torch.Tensor:
    """Return the part of ``residual`` that lines up with ``x`` along dim 0.

    When ``x`` and ``residual`` already have the same number of rows, or when
    ``get_forward_context()`` raises ``AssertionError`` (presumably because no
    forward context is active -- confirm against vLLM), ``residual`` is
    returned unchanged. Otherwise ``residual`` is zero-padded at the end of
    dim 0 by ``forward_context.pad_size`` rows and split into one chunk per
    tensor-parallel rank; the chunk belonging to the current rank is returned.
    """
    try:
        ctx = get_forward_context()
    except AssertionError:
        return residual

    # Nothing to do when the row counts already agree.
    if x.size(0) == residual.size(0):
        return residual

    enabled = ctx.sp_enabled
    assert enabled is True, ("Currently, this situation only occurs "
                             "when sp is enabled")

    padding = ctx.pad_size
    if padding > 0:
        # Pad only the trailing side of dim 0; the last dim is untouched.
        residual = F.pad(residual, (0, 0, 0, padding))

    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()
    return torch.chunk(residual, world_size, dim=0)[rank]
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
                                           label: bool) -> torch.Tensor:
    """All-gather ``x`` along dim 0 and strip trailing padding rows.

    The gather only happens when both ``forward_context.sp_enabled`` and
    ``label`` are truthy. In every other case -- including when
    ``get_forward_context()`` raises ``AssertionError`` -- ``x`` is returned
    as-is.
    """
    try:
        ctx = get_forward_context()
    except AssertionError:
        return x

    if not (ctx.sp_enabled and label):
        return x

    gathered = tensor_model_parallel_all_gather(x, 0)
    padding = ctx.pad_size
    if padding > 0:
        # Drop the last `padding` rows of the gathered tensor.
        gathered = gathered[:-padding, :]
    return gathered
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
    """Reduce ``x`` across tensor-parallel ranks.

    With ``forward_context.sp_enabled`` set, ``x`` is first zero-padded at the
    end of dim 0 by ``forward_context.pad_size`` rows (when positive) and then
    reduce-scattered along dim 0. Otherwise -- or when
    ``get_forward_context()`` raises ``AssertionError`` -- a plain all-reduce
    is performed.
    """
    try:
        ctx = get_forward_context()
    except AssertionError:
        return tensor_model_parallel_all_reduce(x)

    if not ctx.sp_enabled:
        return tensor_model_parallel_all_reduce(x)

    padding = ctx.pad_size
    if padding > 0:
        x = F.pad(x, (0, 0, 0, padding))
    return tensor_model_parallel_reduce_scatter(x, 0)
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
                                          prefix: str) -> None:
    """Issue a prefetch of a layer's MLP ``gate_up_proj`` weight.

    Does nothing when ``get_forward_context()`` raises ``AssertionError`` or
    when ``forward_context.prefetch_mlp_enabled`` is falsy.

    Args:
        x_dependency: tensor handed to ``torch_npu.npu_prefetch`` as its
            second argument.
        prefix: dotted module name; its third component is parsed as the
            layer index, and a second-to-last component of ``"self_attn"``
            switches the gate_up prefetch flag on.
    """
    try:
        ctx = get_forward_context()
    except AssertionError:
        return

    if not ctx.prefetch_mlp_enabled:
        return

    model = ctx.model_instance
    side_stream = ctx.prefetch_stream
    name_parts = prefix.split('.')
    layer_idx = int(name_parts[2])

    # start point of gate_up_proj weight prefetch
    if name_parts[-2] == "self_attn":
        ctx.prefetch_mlp_gate_up_proj = True

    if not ctx.prefetch_mlp_gate_up_proj:
        return

    # Make the prefetch stream wait for work already queued on the current
    # stream before the prefetch is issued on it.
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        prefetch_size = envs_ascend.VLLM_ASCEND_MLP_GATE_UP_PREFETCH_SIZE
        weight = model.model.layers[layer_idx].mlp.gate_up_proj.weight
        torch_npu.npu_prefetch(weight, x_dependency, prefetch_size)
    return
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
                                               prefix: str) -> None:
    """No-op stand-in registered as the fake impl of the gate_up prefetch op."""
    return None
def _maybe_prefetch_mlp_down_proj_impl(x_dependency: torch.Tensor) -> None:
    """Issue a prefetch of the current layer's MLP ``down_proj`` weight.

    Does nothing when ``get_forward_context()`` raises ``AssertionError`` or
    when ``forward_context.prefetch_mlp_enabled`` is falsy. Otherwise the
    down_proj prefetch flag is set, the prefetch is issued on the prefetch
    stream for layer ``forward_context.layer_idx``, and that index is then
    incremented by one.

    Args:
        x_dependency: tensor handed to ``torch_npu.npu_prefetch`` as its
            second argument.
    """
    try:
        ctx = get_forward_context()
    except AssertionError:
        return

    if not ctx.prefetch_mlp_enabled:
        return

    ctx.prefetch_mlp_down_proj = True
    model = ctx.model_instance
    side_stream = ctx.prefetch_stream
    current_layer = ctx.layer_idx

    # start point of down_proj weight prefetch
    side_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(side_stream):
        prefetch_size = envs_ascend.VLLM_ASCEND_MLP_DOWN_PREFETCH_SIZE
        weight = model.model.layers[current_layer].mlp.down_proj.weight
        torch_npu.npu_prefetch(weight, x_dependency, prefetch_size)

    # Advance the layer counter kept on the forward context.
    ctx.layer_idx += 1
    return
def _maybe_prefetch_mlp_down_proj_impl_fake(
        x_dependency: torch.Tensor) -> None:
    """No-op stand-in registered as the fake impl of the down_proj prefetch op."""
    return None
def _maybe_wait_prefetch_done_impl(x: torch.Tensor) -> None:
    """Make the current stream wait for any outstanding MLP weight prefetch.

    Does nothing when ``get_forward_context()`` raises ``AssertionError``,
    when ``forward_context.prefetch_mlp_enabled`` is falsy, or when neither
    prefetch flag is set. After the wait is queued both prefetch flags are
    cleared. ``x`` is not read by this function.
    """
    try:
        ctx = get_forward_context()
    except AssertionError:
        return

    if not ctx.prefetch_mlp_enabled:
        return

    pending = ctx.prefetch_mlp_gate_up_proj or ctx.prefetch_mlp_down_proj
    if not pending:
        return

    side_stream = ctx.prefetch_stream
    # wait until prefetch done
    torch.npu.current_stream().wait_stream(side_stream)
    ctx.prefetch_mlp_gate_up_proj = False
    ctx.prefetch_mlp_down_proj = False
    return
def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
    """No-op stand-in registered as the fake impl of the wait-prefetch op."""
    return None
def _maybe_all_reduce_tensor_model_parallel_impl(
        final_hidden_states: torch.Tensor) -> torch.Tensor:
    """All-reduce ``final_hidden_states`` unless the MoE comm type skips it.

    When ``forward_context.moe_comm_type`` is ``MoECommType.ALLTOALL`` or
    ``MoECommType.MC2`` the input is returned untouched; for any other value
    it is all-reduced across the tensor-parallel group.

    Unlike the other ops in this module, an ``AssertionError`` raised by
    ``get_forward_context()`` is not caught here.
    """
    comm_type = get_forward_context().moe_comm_type
    if comm_type not in {MoECommType.ALLTOALL, MoECommType.MC2}:
        return tensor_model_parallel_all_reduce(final_hidden_states)
    return final_hidden_states
# Register every helper above as a custom op. All registrations use the
# "PrivateUse1" dispatch key and declare no mutated arguments. Each op is
# given a `fake_impl`, presumably the stand-in used when only output
# metadata is needed (e.g. during tracing) -- confirm against
# vllm.utils.direct_register_custom_op.

# NOTE(review): the fake impl returns `residual` unchanged, whereas the real
# impl may return a padded-and-chunked slice with fewer rows -- confirm this
# shape difference is acceptable to callers that trace the op.
direct_register_custom_op(op_name="maybe_chunk_residual",
                          op_func=_maybe_chunk_residual_impl,
                          fake_impl=lambda x, residual: residual,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")
# NOTE(review): the fake impl returns `x` unchanged, whereas the real impl may
# all-gather along dim 0 and strip padding -- confirm.
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                          op_func=_maybe_all_gather_and_maybe_unpad_impl,
                          fake_impl=lambda x, label: x,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")
# NOTE(review): the fake impl returns `x` unchanged, whereas the real impl may
# pad and reduce-scatter along dim 0 -- confirm.
direct_register_custom_op(op_name="maybe_pad_and_reduce",
                          op_func=_maybe_pad_and_reduce_impl,
                          fake_impl=lambda x: x,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")
# The three prefetch-related ops return None; their fake impls are no-ops.
direct_register_custom_op(op_name="maybe_prefetch_mlp_gate_up_proj",
                          op_func=_maybe_prefetch_mlp_gate_up_proj_impl,
                          fake_impl=_maybe_prefetch_mlp_gate_up_proj_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_prefetch_mlp_down_proj",
                          op_func=_maybe_prefetch_mlp_down_proj_impl,
                          fake_impl=_maybe_prefetch_mlp_down_proj_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_wait_prefetch_done",
                          op_func=_maybe_wait_prefetch_done_impl,
                          fake_impl=_maybe_wait_prefetch_done_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")
# Both the skip path and the all-reduce path of the real impl take one tensor
# and return one tensor; the fake impl returns its input.
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
                          op_func=_maybe_all_reduce_tensor_model_parallel_impl,
                          fake_impl=lambda x: x,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

View File

@@ -20,6 +20,7 @@ from typing import Optional, Tuple
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@@ -37,34 +38,39 @@ def _rope_forward_oot(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
is_neox_style: bool,
offsets: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
neox_style = self.is_neox_style
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if _custom_rotary_embedding_enabled(query, neox_style,
if _custom_rotary_embedding_enabled(query, is_neox_style,
self.head_size) and not is_310p():
query, key = torch.ops._C.rotary_embedding(
query, key = torch.ops._C_ascend.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if self.rotary_dim < self.head_size:
if self.cos is not None and \
self.sin is not None:
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim to both equal 128 and neox_style to be True.
query = query.contiguous().view(1, query.shape[0], -1,
self.head_size)
key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
elif self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
@@ -80,25 +86,26 @@ def _rope_forward_oot(
k_rot,
self.head_size,
self.cos_sin_cache,
neox_style,
is_neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
else:
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
is_neox_style,
)
return query.view(query_shape), key.view(key_shape)
class AscendRotaryEmbedding(RotaryEmbedding):
@@ -112,6 +119,8 @@ class AscendRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
self.cos = None
self.sin = None
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
@@ -123,14 +132,25 @@ class AscendRotaryEmbedding(RotaryEmbedding):
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
):
return _rope_forward_oot(
self,
positions,
query,
key,
offsets,
is_neox_style_override,
)
is_neox_style = self.is_neox_style
if is_neox_style_override is not None:
is_neox_style = is_neox_style_override
forward_context = get_forward_context()
is_first_layer = forward_context.is_first_layer
# Generate cos and sin outside layers to avoid repeated calculation.
if is_neox_style and self.head_size == 128 and self.cos_sin_cache.shape[
-1] == 128:
if is_first_layer:
cos_sin = self.cos_sin_cache.index_select(0, positions)
last_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, last_dim // 2).repeat(
1, 1, 2).chunk(2, dim=-2)
# BSNH
self.cos = cos.view(1, -1, 1, last_dim).contiguous()
self.sin = sin.view(1, -1, 1, last_dim).contiguous()
forward_context.is_first_layer = False
return _rope_forward_oot(self, positions, query, key, is_neox_style,
offsets)
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
@@ -168,8 +188,10 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings,
base, is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
self._set_cos_sin_cache(seq_len=max_position_embeddings,
# NOTE: For ascend friendly computing, reorder sin and cos cache
self.max_seq_len = math.ceil(max_position_embeddings * scaling_factor)
self._set_cos_sin_cache(self.max_seq_len,
device=NPUPlatform.device_type,
dtype=dtype)
@@ -275,8 +297,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
return q_embed, k_embed
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
def _set_cos_sin_cache(self, max_seq_len, device, dtype):
dim = self.rotary_dim
freq_extra = 1.0 / (self.base**(
@@ -297,9 +318,7 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len * self.scaling_factor,
device=device,
dtype=torch.float32)
t = torch.arange(max_seq_len, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
@@ -317,16 +336,13 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
max_seq_len: Optional[int] = None):
if max_seq_len is not None and max_seq_len > self.max_seq_len:
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
offsets: Optional[torch.Tensor] = None):
if len(key.shape) == 2:
key = key[:, None, :]
# Note: the non-neox-style variant is implemented by shuffling the last dim and then applying the
# neox-style calculation, which is also more compute-friendly on Ascend hardware.
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
is_neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2,
@@ -334,6 +350,6 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3,
2).reshape(b, h_k, d)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
neox_style)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
is_neox_style, offsets)
return q_pe, k_pe

View File

@@ -0,0 +1,384 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import os
from typing import Optional
import torch
from vllm.triton_utils import tl, tldevice, triton
# Choose the math primitives (`div`, `exp`, `log`, `log2`) used by the kernels
# in this module. When the environment variable FLA_USE_FAST_OPS is exactly
# "1", the `tldevice.fast_*` variants are bound; for any other value
# (default "0") plain division and the `tl` exp/log/log2 functions are bound.
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:

    @triton.jit
    def div_normal(x, y):
        # Plain `/` wrapped in a jitted function so that `div` is a callable
        # in both branches.
        return x / y

    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2
# The heuristics below derive four compile-time flags from whether the
# corresponding pointer argument was passed as None.
@triton.heuristics({
    'USE_INITIAL_STATE':
    lambda args: args['h0'] is not None,
    'IS_VARLEN':
    lambda args: args['cu_seqlens'] is not None,
    "IS_CONTINUOUS_BATCHING":
    lambda args: args['ssm_state_indices'] is not None,
    "IS_SPEC_DECODING":
    lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.
    constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    """Token-by-token recurrent forward pass of the gated delta rule.

    Each program handles one (K-block, V-block, sequence x value-head) triple,
    taken from program ids 0, 1 and 2. It keeps a [BK, BV] float32 state tile
    `b_h` and, for every token of its sequence: scales the state by exp(g),
    subtracts the state's projection on k from v, scales that by beta, adds
    the outer product k x v to the state, writes the output `sum(b_h * q)` to
    `o`, and stores the state tile to `ht`.

    NOTE(review): `stride_indices_tok` (and `N`) are accepted but never
    referenced in the body; the token offset into `ssm_state_indices` is added
    unscaled -- confirm that the index tensor has unit stride along its token
    axis.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Split the fused third axis into sequence index and value-head index.
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Index of the q/k head that this value head reads from (HV // H value
    # heads per q/k head).
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # Sequence boundaries come from cu_seqlens; T is rebound to this
        # sequence's own length and `all` keeps the original token count.
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    # Offsets covered by this program along K and V, with bounds masks.
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Running state tile, accumulated in float32.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            # Pick which entry of this sequence's row in ssm_state_indices
            # names the initial state slot: entry (num_accepted_tokens - 1)
            # when num_accepted_tokens is given, entry 0 otherwise.
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t).to(tl.int64) * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        # Pointers to token (bos + i_t) for this head in each input/output.
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        if IS_BETA_HEADWISE:
            # beta has one value per V element (same layout as v).
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            # beta has one scalar per (token, value head).
            p_beta = beta + bos * HV + i_hv + HV * i_t
        p_g = g + bos * HV + i_hv + HV * i_t
        # `o` carries a leading axis indexed by i_k with `all` tokens each.
        p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_g = tl.load(p_g).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # L2-normalize q and k; 1e-6 is added under the square root.
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BK, BV]
        # Decay the state by exp(g), using the module-level `exp` alias
        # rather than calling tl.exp directly.
        b_h *= exp(b_g)
        # [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        # (the state tile is written out after every token, not only the last)
        if INPLACE_FINAL_STATE:
            # Destination slot is looked up per token in ssm_state_indices.
            p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t).to(tl.int64) * stride_final_state_token
        else:
            # Destination slot is the token's own position (bos + i_t).
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Allocate buffers and launch the fused recurrent gated-delta-rule kernel.

    Args:
        q, k: 4-D tensors; ``k.shape`` supplies ``(B, T, H, K)``.
        v: tensor whose dim 2 is ``HV`` and whose last dim is ``V``.
        g, beta: passed straight to the kernel; ``beta`` is treated as
            head-wise when it has the same number of dims as ``v``.
        scale: multiplier applied to q inside the kernel.
        initial_state: state tensor; its dim-0 stride and dtype are read
            here, so it must be a tensor.
        inplace_final_state: when true the kernel writes states back into
            ``initial_state``; otherwise a fresh ``(T, HV, K, V)`` buffer is
            allocated.
        cu_seqlens: optional cumulative sequence lengths; when given, the
            number of sequences is ``len(cu_seqlens) - 1``.
        ssm_state_indices: optional 1-D or 2-D index tensor whose strides are
            forwarded to the kernel.
        num_accepted_tokens: optional tensor forwarded to the kernel.
        use_qk_l2norm_in_kernel: forwarded to the kernel as a compile-time
            flag.

    Returns:
        ``(o, final_state)`` where ``o`` has the shape of ``v``.
    """
    batch, seq_len, qk_heads, key_dim = k.shape
    value_dim = v.shape[-1]
    v_heads = v.shape[2]
    num_seqs = batch if cu_seqlens is None else len(cu_seqlens) - 1

    # K is covered by a single block; V is tiled in blocks of at most 8.
    block_k = triton.next_power_of_2(key_dim)
    block_v = min(triton.next_power_of_2(value_dim), 8)
    k_blocks = triton.cdiv(key_dim, block_k)
    v_blocks = triton.cdiv(value_dim, block_v)
    assert k_blocks == 1, "NK > 1 is not supported yet"

    # Output carries a leading K-block axis that is squeezed away at the end.
    out = q.new_empty(k_blocks, *v.shape)
    final_state = (initial_state if inplace_final_state else q.new_empty(
        seq_len, v_heads, key_dim, value_dim, dtype=initial_state.dtype))

    init_token_stride = initial_state.stride(0)
    final_token_stride = final_state.stride(0)

    if ssm_state_indices is None:
        index_strides = (1, 1)
    elif ssm_state_indices.ndim == 1:
        index_strides = (ssm_state_indices.stride(0), 1)
    else:
        index_strides = ssm_state_indices.stride()
    seq_stride, tok_stride = index_strides

    launch_grid = (k_blocks, v_blocks, num_seqs * v_heads)
    fused_recurrent_gated_delta_rule_fwd_kernel[launch_grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=out,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=num_seqs,
        T=seq_len,
        B=batch,
        H=qk_heads,
        HV=v_heads,
        K=key_dim,
        V=value_dim,
        BK=block_k,
        BV=block_v,
        stride_init_state_token=init_token_stride,
        stride_final_state_token=final_token_stride,
        stride_indices_seq=seq_stride,
        stride_indices_tok=tok_stride,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        num_warps=1,
        num_stages=3,
    )
    return out.squeeze(0), final_state
class FusedRecurrentFunction(torch.autograd.Function):
    """``torch.autograd.Function`` wrapper around the fused recurrent forward.

    Only ``forward`` is defined in this class.
    """

    @staticmethod
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                inplace_final_state: bool = True,
                cu_seqlens: Optional[torch.LongTensor] = None,
                ssm_state_indices: Optional[torch.Tensor] = None,
                num_accepted_tokens: Optional[torch.Tensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        """Run the forward helper on contiguous copies of the five inputs.

        ``q``, ``k``, ``v``, ``g`` and ``beta`` are made contiguous before the
        call; every other argument is forwarded unchanged. Returns the
        ``(o, final_state)`` pair produced by the helper.
        """
        dense_q = q.contiguous()
        dense_k = k.contiguous()
        dense_v = v.contiguous()
        dense_g = g.contiguous()
        dense_beta = beta.contiguous()
        outputs, state = fused_recurrent_gated_delta_rule_fwd(
            q=dense_q,
            k=dense_k,
            v=dense_v,
            g=dense_g,
            beta=dense_beta,
            scale=scale,
            initial_state=initial_state,
            inplace_final_state=inplace_final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        return outputs, state
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Validate the arguments, fill in defaults and run the fused recurrent
    gated-delta-rule forward pass.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (Optional[torch.Tensor]):
            betas of shape `[B, T, HV]`.
            If not provided, a tensor of ones of shape `[B, T, HV]` with the
            dtype and device of `q` is used. Default: `None`.
        scale (Optional[float]):
            Scale factor multiplied into the queries inside the kernel.
            Must be positive when given.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
            NOTE(review): the forward helper reads `initial_state.stride(0)`,
            so a tensor appears to be required in practice -- confirm.
        inplace_final_state (bool):
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (Optional[torch.LongTensor]):
            Cumulative sequence lengths of shape `[N+1]` used for
            variable-length inputs, consistent with the FlashAttention API.
            Requires a batch size of 1.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.
        use_qk_l2norm_in_kernel (bool):
            Whether q and k are L2-normalized inside the kernel.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            When `inplace_final_state` is `True`, this is `initial_state`
            itself, updated in place; otherwise a newly allocated buffer of
            shape `[T, HV, K, V]`.

    Raises:
        ValueError: if `cu_seqlens` is given and the batch size is not 1.
        AssertionError: if `scale` is given and is not positive.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            cu_seqlens=cu_seqlens
        )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            f"Please flatten variable-length inputs before processing.")
    if scale is None:
        scale = k.shape[-1]**-0.5
    else:
        # Kept as an assert so callers that rely on AssertionError still work.
        assert scale > 0, "scale must be positive"
    if beta is None:
        # The kernel indexes a scalar beta per (token, value head), so the
        # default must have HV entries per token (taken from `v`), not H
        # entries (from `q`). The two coincide when HV == H.
        beta = q.new_ones(v.shape[:-1])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state

View File

@@ -97,6 +97,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the vocabulary dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
@@ -252,3 +253,16 @@ class AscendLogitsProcessor(LogitsProcessor):
logits = logits[..., :self.org_vocab_size]
return logits
def forward(
    self,
    lm_head: VocabParallelEmbedding,
    hidden_states: torch.Tensor,
    # keep this for version compatibility
    sampling_metadata=None,  # type: ignore
    embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Delegate to ``LogitsProcessor.forward`` without ``sampling_metadata``.

    ``sampling_metadata`` is accepted so that callers which still pass it keep
    working, but it is not forwarded; only ``lm_head``, ``hidden_states`` and
    ``embedding_bias`` reach the base implementation.
    """
    return LogitsProcessor.forward(self,
                                   lm_head,
                                   hidden_states,
                                   embedding_bias=embedding_bias)

View File

@@ -46,6 +46,27 @@
# Need a PR to vllm to support get port from environment.
# Future Plan:
# Remove those patch when vllm merged them
# 2. `torch.distributed.all_reduce`, `torch.distributed.broadcast`
# Why:
# tensor alignment for 310p
# How
# rewrite all_reduce and broadcast in torch.distributed
# Related PR (if no, explain why):
# No, not ready yet.
# Future Plan:
# Find a better way to support tensor alignment for 310p without this patch.
#
# ** File: platform/patch_common/patch_multimodal_merge.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.utils._merge_multimodal_embeddings`
# Why:
# '_merge_multimodal_embeddings' func of vllm is incompatible with Ascend.
# How
# Replace with CPU operation that can be executed asynchronously.
# Related PR (if no, explain why):
# This bug is specific to Ascend, so it can't be fixed in vLLM.
# Future Plan:
# Identify this pattern in torch-npu and remove this patch.
#
# * Worker Patch:
# ===============
@@ -86,19 +107,15 @@
# - https://github.com/vllm-project/vllm/pull/21591
# Future Plan:
# Revert it when vLLM merge #21591 and release new version
# ** File: worker/patch_common/patch_linear.py **
# ** File: worker/patch_common/patch_logits.py **
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.linear.RowParallelLinear`
# 1. `vllm._custom_ops.apply_repetition_penalties`
# Why:
# We need to fuse matmul and allreuce in `RowParallelLinear`
# to improve performance.
# apply_repetition_penalties in vLLM uses tensor.is_cuda to check whether a tensor is on CUDA, but that value is always True
# on Ascend, so we need to patch apply_repetition_penalties.
# How
# Create a new class `AscendRowParallelLinear` that inherits from `RowParallelLinear`.
# In this class, we override the `forward` method to use
# torch_npu.npu_mm_all_reduce_base to replace matmul and allreduce.
# Remove the related cuda check in apply_repetition_penalties.
# Related PR (if no, explain why):
# - https://github.com/vllm-project/vllm-ascend/pull/1926
# - This bug is specific to Ascend, so it can't be fixed in vLLM.
# Future Plan:
# Validate more models in all kinds of scenario,
# if performance is always improved, we can enable this patch by default and remove the env
# variable `VLLM_ASCEND_ENABLE_FUSE_MATMUL_ALLREDUCE` in the future.
# Fix this bug in torch-npu, bump torch-npu version and remove this patch.

View File

@@ -15,4 +15,10 @@
# limitations under the License.
#
import vllm_ascend.patch.platform.patch_common.patch_config # noqa
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_common.patch_multimodal_merge # noqa
import vllm_ascend.patch.platform.patch_common.patch_transformers_utils # noqa
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa

View File

@@ -0,0 +1,313 @@
import ast
import vllm.envs as envs
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import logger
# mypy: ignore-errors
@property
def is_deepseek_mla(self: ModelConfig):
    """Report whether the model is a DeepSeek-style model with ``kv_lora_rank`` set.

    Returns ``False`` when the text config has no ``model_type``. For the
    listed DeepSeek-family model types the answer is whether
    ``kv_lora_rank`` is not ``None``. For ``'eagle'`` the same check is
    applied, but only when the nested ``model.model_type`` is one of the
    DeepSeek types. Any other model type yields ``False``.
    """
    direct_types = ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2',
                    'longcat_flash', 'deepseek_v32')
    eagle_base_types = ('deepseek_v2', 'deepseek_v3', 'deepseek_v32')

    text_config = self.hf_text_config
    if not hasattr(text_config, "model_type"):
        return False

    model_type = text_config.model_type
    if model_type in direct_types:
        return text_config.kv_lora_rank is not None

    if model_type == 'eagle':
        # if the model is an EAGLE module, check for the
        # underlying architecture
        base_type = text_config.model.model_type
        return (base_type in eagle_base_types
                and text_config.kv_lora_rank is not None)

    return False
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
    """Rewrite ``hf_config`` in place so that it describes the MTP variant.

    Depending on ``model_type`` or the first entry of ``architectures``, the
    config's ``model_type`` is renamed and ``n_predict`` / ``architectures``
    (and for MiMo and GLM4-MoE also ``num_hidden_layers = 0``) are set through
    ``hf_config.update``. ``n_predict`` is copied from
    ``num_nextn_predict_layers`` and falls back to ``None`` (``1`` for
    LongCat-Flash) when that attribute is missing. Configs matching none of
    the rules are left untouched. The same object is returned.
    """

    def apply_mtp(architecture, fallback_n_predict=None, zero_layers=False):
        # Build the override dict in the order num_hidden_layers, n_predict,
        # architectures and push it into the config.
        overrides = {}
        if zero_layers:
            overrides["num_hidden_layers"] = 0
        overrides["n_predict"] = getattr(hf_config,
                                         "num_nextn_predict_layers",
                                         fallback_n_predict)
        overrides["architectures"] = [architecture]
        hf_config.update(overrides)

    # DeepSeek V3 / V3.2 are first renamed, then handled as deepseek_mtp.
    if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
        hf_config.model_type = "deepseek_mtp"
    if hf_config.model_type == "deepseek_mtp":
        apply_mtp("DeepSeekMTPModel")

    # MiMo and GLM4-MoE are recognised by architecture, not model_type; the
    # architecture list is re-read each time because earlier rules may have
    # replaced it.
    if hf_config.architectures[0] == "MiMoForCausalLM":
        hf_config.model_type = "mimo_mtp"
        apply_mtp("MiMoMTPModel", zero_layers=True)

    if hf_config.architectures[0] == "Glm4MoeForCausalLM":
        hf_config.model_type = "glm4_moe_mtp"
        apply_mtp("Glm4MoeMTPModel", zero_layers=True)

    if hf_config.model_type == "ernie4_5_moe":
        hf_config.model_type = "ernie_mtp"
    if hf_config.model_type == "ernie_mtp":
        apply_mtp("ErnieMTPModel")

    if hf_config.model_type == "qwen3_next":
        hf_config.model_type = "qwen3_next_mtp"
    if hf_config.model_type == "qwen3_next_mtp":
        apply_mtp("Qwen3NextMTP")

    if hf_config.model_type == "longcat_flash":
        hf_config.model_type = "longcat_flash_mtp"
        apply_mtp("LongCatFlashMTPModel", fallback_n_predict=1)

    return hf_config
def __post_init__(self):
    """Resolve the speculative method and build the draft-model configs.

    Steps, in order:
      1. If no ``model`` was given but ``num_speculative_tokens`` was,
         reuse the target model for known MTP model types, or select
         ngram when ``method`` says so; otherwise raise ``ValueError``.
      2. For ngram, normalise the method name, fill in and validate the
         ``prompt_lookup_min`` / ``prompt_lookup_max`` window, and reuse
         the target model/parallel configs as the draft configs.
      3. Otherwise, when a draft ``model`` is set, build its
         ``ModelConfig``, auto-detect ``method`` from the draft model
         name or HF ``model_type``, wrap EAGLE configs, reconcile
         ``num_speculative_tokens`` with ``n_predict``, normalise the
         token tree, and derive draft TP size, max model length and
         parallel config.

    Raises:
        ValueError: on an inconsistent configuration (missing draft
            model, invalid prompt-lookup window, non-divisible
            ``num_speculative_tokens``, EAGLE + chunked prefill on V0).
        NotImplementedError: when the draft model matches no supported
            speculative method.
    """
    # Note: "method" is a new parameter that helps to extend the
    # configuration of non-model-based proposers, and the "model" parameter
    # will be used to set the draft model, eagle head, or additional weight
    # when needed. If users do not specify "method", the speculative method
    # will be detected automatically if possible. If the speculative method
    # can not be detected, it will be considered as the "draft_model" by
    # default.

    if self.model is None and self.num_speculative_tokens is not None:
        # TODO(Shangming): Refactor mtp configuration logic when supporting
        # more MTP-capable models.
        if (self.target_model_config
                and self.target_model_config.hf_text_config.model_type
                in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe",
                    "qwen3_next")):
            # use the draft model from the same model:
            self.model = self.target_model_config.model
            # Align the quantization of draft model for cases such as
            # --quantization fp8 with a bf16 checkpoint.
            if not self.quantization:
                self.quantization = self.target_model_config.quantization
        elif self.method in ("ngram", "[ngram]"):
            self.model = "ngram"
        else:
            raise ValueError("num_speculative_tokens was provided but without "
                             "speculative model.")

    # Automatically configure the method for ngram when "model" is used
    # instead of "method"
    if self.method is None and (self.model is not None
                                and self.model in ("ngram", "[ngram]")):
        self.method = "ngram"

    if self.method in ("ngram", "[ngram]"):
        # Unified to "ngram" internally
        self.method = "ngram"
        # Set default values if not provided
        if (self.prompt_lookup_min is None
                and self.prompt_lookup_max is None):
            # TODO(woosuk): Tune these values. They are arbitrarily chosen.
            self.prompt_lookup_min = 5
            self.prompt_lookup_max = 5
        elif self.prompt_lookup_min is None:
            assert self.prompt_lookup_max is not None
            self.prompt_lookup_min = self.prompt_lookup_max
        elif self.prompt_lookup_max is None:
            assert self.prompt_lookup_min is not None
            self.prompt_lookup_max = self.prompt_lookup_min

        # Validate values
        if self.prompt_lookup_min < 1:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
        if self.prompt_lookup_max < 1:
            raise ValueError(
                f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
        if self.prompt_lookup_min > self.prompt_lookup_max:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must "
                f"be <= prompt_lookup_max={self.prompt_lookup_max}")

        # TODO: current we still need extract vocab_size from target model
        # config, in future, we may try refactor it out, and set
        # draft related config as None here.
        self.draft_model_config = self.target_model_config
        self.draft_parallel_config = self.target_parallel_config
    else:
        self.prompt_lookup_max = 0
        self.prompt_lookup_min = 0

        if self.model is not None:
            # TODO: Move this import to the top once `ModelConfig`
            # lives in `vllm.config.model`.
            from vllm.config import ModelConfig
            self.draft_model_config = ModelConfig(
                model=self.model,
                runner="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.trust_remote_code,
                allowed_local_media_path=self.target_model_config.
                allowed_local_media_path,
                allowed_media_domains=self.target_model_config.
                allowed_media_domains,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.tokenizer_revision,
                spec_target_max_model_len=self.target_model_config.
                max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )

            # `num_speculative_tokens` may still be None here (it is only
            # defaulted from `n_predict` further down), so guard the
            # comparison instead of letting `None > 1` raise TypeError.
            wants_multiple_tokens = (self.num_speculative_tokens is not None
                                     and self.num_speculative_tokens > 1)

            # Automatically detect the method
            if self.method in ('eagle', 'eagle3'):
                pass
            # examples:
            # yuhuili/EAGLE-LLaMA3-Instruct-8B
            # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
            # AngelSlim/Qwen3-8B_eagle3
            elif "eagle-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif "eagle3" in self.draft_model_config.model.lower():
                self.method = "eagle3"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif (self.draft_model_config.hf_config.model_type ==
                  "mlp_speculator"):
                self.method = "mlp_speculator"
            elif (self.draft_model_config.hf_config.model_type
                  in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
                self.method = "deepseek_mtp"
                if wants_multiple_tokens:
                    logger.warning(
                        "All Deepseek MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"):
                self.method = "ernie_mtp"
                if wants_multiple_tokens:
                    logger.warning(
                        "All Ernie MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            elif (self.draft_model_config.hf_config.model_type ==
                  "qwen3_next_mtp"):
                self.method = "qwen3_next_mtp"
                if wants_multiple_tokens:
                    logger.warning(
                        "All Qwen3Next MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            # Exact comparison: `in ("longcat_flash_mtp")` would be a
            # substring test against a plain string, not a tuple lookup.
            elif (self.draft_model_config.hf_config.model_type ==
                  "longcat_flash_mtp"):
                self.method = "longcat_flash_mtp"
                if wants_multiple_tokens:
                    logger.warning(
                        "LongCat MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            else:
                self.method = "draft_model"
                raise NotImplementedError(
                    "Speculative decoding with draft model is not "
                    "supported yet. Please consider using other "
                    "speculative decoding methods such as ngram, medusa, "
                    "eagle, or deepseek_mtp.")

            # Replace hf_config for EAGLE draft_model
            if self.method in ("eagle", "eagle3"):
                if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
                    raise ValueError(
                        "Chunked prefill and EAGLE are not compatible "
                        "when using V0.")

                from vllm.transformers_utils.configs import SpeculatorsConfig
                from vllm.transformers_utils.configs.eagle import EAGLEConfig

                if isinstance(self.draft_model_config.hf_config,
                              (EAGLEConfig, SpeculatorsConfig)):
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config,
                        method=self.method,
                        model_type="eagle")
                    self.draft_model_config.hf_config = eagle_config

            if (self.num_speculative_tokens is not None
                    and hasattr(self.draft_model_config.hf_config,
                                "num_lookahead_tokens")):
                self.draft_model_config.hf_config.num_lookahead_tokens = \
                    self.num_speculative_tokens

            n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
                                None)
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    self.num_speculative_tokens = n_predict
                elif self.num_speculative_tokens > n_predict and \
                        self.num_speculative_tokens % n_predict != 0:
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:{self.num_speculative_tokens}"
                        f" must be divisible by {n_predict=}")

            if self.speculative_token_tree is None:
                # Generate chain of tokens.
                self.speculative_token_tree = str([
                    (i + 1) * (0, ) for i in range(self.num_speculative_tokens)
                ])
            else:
                # Sort the token tree breadth-first.
                tree_choices = ast.literal_eval(self.speculative_token_tree)
                self.speculative_token_tree = str(
                    sorted(tree_choices, key=lambda t: (len(t), t)))

            self.draft_tensor_parallel_size = \
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config
                )

            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                ))

            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size))
# Monkey-patch vLLM: install the replacement implementations defined in this
# module onto the upstream ModelConfig / SpeculativeConfig classes.
ModelConfig.is_deepseek_mla = is_deepseek_mla
SpeculativeConfig.__post_init__ = __post_init__
SpeculativeConfig.hf_config_override = hf_config_override

View File

@@ -0,0 +1,100 @@
# mypy: ignore-errors
import vllm.model_executor.models.config
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm_ascend.ascend_config import get_ascend_config
@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
    """
    Ensure that page size of attention layers is greater than or
    equal to the mamba layers. If not, automatically set the attention
    block size to ensure that it is. If the attention page size is
    strictly greater than the mamba page size, we pad the mamba page size
    to make them equal.

    ``vllm_config.cache_config`` is updated in place (``block_size`` and
    ``mamba_page_size_padded``). Nothing is changed when the mamba page
    size is zero.

    Args:
        vllm_config: vLLM Config
    """
    logger = init_logger(__name__)
    # Enable FULL_AND_PIECEWISE by default
    MambaModelConfig.verify_and_update_config(vllm_config)

    ascend_config = get_ascend_config()
    cache_config = vllm_config.cache_config
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config

    if cache_config.cache_dtype == "auto":
        kv_cache_dtype = model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

    # get attention page size (for 1 token)
    attn_page_size_1_token = FullAttentionSpec(
        block_size=1,
        num_kv_heads=model_config.get_num_kv_heads(parallel_config),
        head_size=model_config.get_head_size(),
        dtype=kv_cache_dtype,
        use_mla=model_config.use_mla or ascend_config.use_sfa).page_size_bytes

    model_cls, _ = ModelRegistry.resolve_model_cls(
        model_config.architecture,
        model_config=model_config,
    )

    # get mamba page size
    mamba_page_size = MambaSpec(
        shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
        dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
        block_size=model_config.max_model_len,
    ).page_size_bytes

    # A zero mamba page size leaves nothing to align against. Continuing
    # would compute an attention block size of 0 and divide by zero when
    # reporting the padding percentage, so return early.
    if mamba_page_size == 0:
        return

    # The attention block size (in tokens) is rounded up to a multiple of
    # this value.
    block_size_alignment = 64

    # Smallest multiple of `block_size_alignment` tokens whose attention
    # page is at least as large as the mamba page.
    attn_block_size = block_size_alignment * cdiv(
        mamba_page_size, block_size_alignment * attn_page_size_1_token)

    # override attention block size if either (a) the
    # user has not set it or (b) the user has set it
    # too small.
    if (cache_config.block_size is None
            or cache_config.block_size < attn_block_size):
        cache_config.block_size = attn_block_size
        logger.info(
            "Setting attention block size to %d tokens "
            "to ensure that attention page size is >= mamba page size.",
            attn_block_size)

    # compute new attention page size
    attn_page_size = \
        cache_config.block_size * attn_page_size_1_token

    assert attn_page_size >= mamba_page_size

    if attn_page_size == mamba_page_size:
        # don't need to pad mamba page size
        return

    # pad mamba page size to exactly match attention
    if (cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size):
        cache_config.mamba_page_size_padded = (attn_page_size)
        mamba_padding_pct = 100 * (attn_page_size -
                                   mamba_page_size) / mamba_page_size
        logger.info(
            "Padding mamba page size by %.2f%% to ensure "
            "that mamba page size and attention page size are "
            "exactly equal.", mamba_padding_pct)
# Monkey-patch vLLM: replace HybridAttentionMambaModelConfig's config
# verification with the Ascend-aware version defined above.
vllm.model_executor.models.config.HybridAttentionMambaModelConfig.verify_and_update_config = verify_and_update_config

View File

@@ -0,0 +1,58 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import torch
import vllm
from vllm.model_executor.models.utils import (_embedding_count_expression,
_flatten_embeddings)
from vllm.multimodal import NestedTensors
def _merge_multimodal_embeddings(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
    """
    Write the flattened ``multimodal_embeddings`` into the rows of
    ``inputs_embeds`` selected by the ``is_multimodal`` mask.

    Note:
        This updates ``inputs_embeds`` in place and returns the same tensor.

    Raises:
        ValueError: if the indexed assignment fails with a ``RuntimeError``.
            The message reports the embedding/placeholder counts when they
            differ, and is generic otherwise.
    """
    mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)

    try:
        inputs_embeds[is_multimodal] = mm_embeds_flat
    except RuntimeError as e:
        # Count the placeholder slots only on failure, to explain the error.
        placeholder_count = is_multimodal.sum().item()
        assert isinstance(placeholder_count, int)

        if mm_embeds_flat.shape[0] == placeholder_count:
            # Counts agree, so the failure has some other cause.
            raise ValueError("Error during masked scatter operation") from e

        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {mm_embeds_flat.shape[0]} "
            f"multimodal tokens to {placeholder_count} placeholders"
        ) from e

    return inputs_embeds
# Monkey-patch vLLM: use the indexed-assignment implementation defined above
# in place of the upstream helper.
vllm.model_executor.models.utils._merge_multimodal_embeddings = _merge_multimodal_embeddings

View File

@@ -0,0 +1,200 @@
import vllm.transformers_utils.configs
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from vllm.transformers_utils import config
logger = logging.get_logger(__name__)
class DeepseekV3Config(PretrainedConfig):
    r"""
    Configuration container for DeepSeek-V3 style models.

    Every constructor argument listed below is stored verbatim as an
    attribute of the same name; the token ids, `tie_word_embeddings` and any
    extra `**kwargs` are forwarded to [`PretrainedConfig`]. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the model.
        hidden_size (`int`, *optional*, defaults to 7168):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of nextn predict layers in the DeepSeekV3 Model.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 128):
            Number of key/value heads. When `None` is passed, falls back to
            `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to 256):
            Number of routed experts, None means dense model.
        ep_size (`int`, *optional*, defaults to 1):
            Stored as-is; not interpreted by this class.
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor of routed experts.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Stored as-is; not interpreted by this class.
        q_lora_rank (`int`, *optional*, defaults to 1536):
            Stored as-is; not interpreted by this class.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Stored as-is; not interpreted by this class.
        v_head_dim (`int`, *optional*, defaults to 128):
            Stored as-is; not interpreted by this class.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Stored as-is; not interpreted by this class.
        topk_method (`str`, *optional*, defaults to `'noaux_tc'`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token (the selected experts
            are only within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every
            `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers in shallow layers
            (embed->dense->dense->...->dense->moe->moe...->lm_head).
                    \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to `'sigmoid'`):
            Method of computing expert weights.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used
            with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE
            embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output
            projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> # Initializing a Deepseek-V3 style configuration
    >>> configuration = DeepseekV3Config()
    >>> configuration.hidden_size
    7168
    ```"""

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method='noaux_tc',
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func='sigmoid',
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Overall model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.max_position_embeddings = max_position_embeddings

        # Attention geometry. `None` key/value heads fall back to the number
        # of attention heads, for backward compatibility.
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = (num_attention_heads
                                    if num_key_value_heads is None else
                                    num_key_value_heads)
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.v_head_dim = v_head_dim
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Mixture-of-experts layout and routing.
        self.moe_intermediate_size = moe_intermediate_size
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func

        # Activation, initialisation, normalisation and caching.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # Rotary position embeddings.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # The base class handles the special token ids and any extra kwargs.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
# Expose DeepseekV3Config through vllm.transformers_utils.configs and map the
# "deepseek_v32" model type to it in vLLM's config registry.
vllm.transformers_utils.configs.__all__.append("DeepseekV3Config")
vllm.transformers_utils.configs.DeepseekV3Config = DeepseekV3Config
config._CONFIG_REGISTRY["deepseek_v32"] = "DeepseekV3Config"

View File

@@ -15,8 +15,18 @@
# limitations under the License.
#
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import vllm_ascend.patch.worker.patch_common.patch_triton
# isort: off
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_linear # noqa
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
import vllm_ascend.patch.worker.patch_common.patch_lora_embedding # noqa
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa
# TODO: revert me when triton import is fixed
# import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa

View File

@@ -0,0 +1,202 @@
from typing import List, Optional
import torch
import vllm
import vllm.envs as envs
from torch import nn
from vllm.attention import Attention, AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm_ascend.utils import vllm_version_is
class AscendAttention(Attention, nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    Only ``__init__`` is overridden here. It accepts an extra ``use_sfa``
    flag, which is forwarded to ``get_attn_backend`` when no explicit
    ``attn_backend`` is supplied.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        use_sfa: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Raises:
            ValueError: if ``prefix`` is already registered in the
                compilation config's static forward context, or if an
                fp8_e5m2 KV cache is combined with a quantized checkpoint.
        """
        # Initialise the base classes directly; Attention.__init__ is
        # intentionally not invoked, this method replaces it.
        nn.Module.__init__(self)
        AttentionLayerBase.__init__(self)
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            # Defaults used when no cache config is supplied.
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        # A "sinks" entry in the extra impl args marks the layer as having
        # attention sinks; the flag is passed on to backend selection.
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            # The two supported vLLM versions differ only in whether
            # `is_attention_free` is passed to get_attn_backend.
            if vllm_version_is("0.10.2"):
                self.attn_backend = get_attn_backend(head_size,
                                                     dtype,
                                                     kv_cache_dtype,
                                                     block_size,
                                                     is_attention_free,
                                                     use_mla=use_mla,
                                                     use_sfa=use_sfa,
                                                     has_sink=self.has_sink)
            else:
                self.attn_backend = get_attn_backend(head_size,
                                                     dtype,
                                                     kv_cache_dtype,
                                                     block_size,
                                                     use_mla=use_mla,
                                                     use_sfa=use_sfa,
                                                     has_sink=self.has_sink)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        # Register this layer under its prefix; prefixes must be unique.
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
        self.query_quant = None
# Monkey-patch vLLM: make `vllm.attention.Attention` resolve to the Ascend
# subclass defined above.
vllm.attention.Attention = AscendAttention

View File

@@ -0,0 +1,181 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mypy: ignore-errors
from functools import cache
from typing import Optional
import torch
import vllm
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.platforms import _Backend, current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.10.2"):
def get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool = False,
    use_mla: bool = False,
    use_sfa: bool = False,
    has_sink: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it.

    Variant defined when ``vllm_version_is("0.10.2")`` is true; it still
    accepts ``is_attention_free``. All arguments, plus the current value of
    ``envs.VLLM_USE_V1``, are forwarded as keywords to the cached
    ``_cached_get_attn_backend``.
    """
    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
    # value to be returned from the cache if the value changes between calls.
    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
    # private function.
    return _cached_get_attn_backend(
        head_size=head_size,
        dtype=dtype,
        kv_cache_dtype=kv_cache_dtype,
        block_size=block_size,
        is_attention_free=is_attention_free,
        use_v1=envs.VLLM_USE_V1,
        use_mla=use_mla,
        use_sfa=use_sfa,
        has_sink=has_sink,
    )
@cache
def _cached_get_attn_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    is_attention_free: bool,
    use_v1: bool = False,
    use_mla: bool = False,
    use_sfa: bool = False,
    has_sink: bool = False,
) -> type[AttentionBackend]:
    """Resolve the attention backend class for the given configuration.

    Results are memoised per argument combination. The backend is chosen
    from, in order of precedence: the placeholder backend for
    attention-free models, the globally forced backend, then the
    ``VLLM_ATTENTION_BACKEND`` environment variable. The current platform
    maps that choice (possibly ``None``) to a qualified class name.

    Raises:
        ValueError: if the environment variable names an unknown backend,
            or the platform returns no backend class.
    """
    # If there are no attention layers (e.g. we are running Mamba),
    # use the placeholder NO_ATTENTION
    if is_attention_free:
        from vllm.attention.backends.placeholder_attn import \
            PlaceholderAttentionBackend
        return PlaceholderAttentionBackend

    # A globally forced backend takes precedence.
    #
    # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
    # ENVIRONMENT VARIABLE.
    backend_by_global_setting: Optional[_Backend] = (
        get_global_forced_attn_backend())
    selected_backend = backend_by_global_setting

    if selected_backend is None:
        # Nothing forced: fall back to the environment variable, if set.
        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
        if backend_by_env_var is not None:
            selected_backend = backend_name_to_enum(backend_by_env_var)
            if selected_backend is None:
                raise ValueError(
                    f"Invalid attention backend: '{backend_by_env_var}'. "
                    f"Valid backends are: {list(_Backend.__members__.keys())}"
                )

    # get device-specific attn_backend
    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size,
        use_v1, use_mla, use_sfa, has_sink)
    if attention_cls:
        return resolve_obj_by_qualname(attention_cls)

    raise ValueError(
        f"Invalid attention backend for {current_platform.device_name}"
    )
else:
def get_attn_backend( # type: ignore[misc]
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
use_mla: bool = False,
use_sfa: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
# private function.
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
use_sfa=use_sfa,
has_sink=has_sink,
)
@cache
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
use_v1: bool = False,
use_mla: bool = False,
use_sfa: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
selected_backend = None
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. "
f"Valid backends are: {list(_Backend.__members__.keys())}"
)
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
use_v1, use_mla, use_sfa, has_sink)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
return resolve_obj_by_qualname(attention_cls)
vllm.attention.get_attn_backend = get_attn_backend
vllm.attention.selector._cached_get_attn_backend = _cached_get_attn_backend

View File

@@ -0,0 +1,110 @@
from dataclasses import dataclass, fields
from typing import Optional
import torch
import vllm
from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.utils import cdiv, get_dtype_size
from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager,
spec_manager_map)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    """KV cache spec for attention layers, extended with a ``use_sfa`` flag.

    Installed over ``vllm.v1.kv_cache_interface.AttentionSpec`` right after
    the class definition.
    """
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool
    use_sfa: bool

    @property
    def page_size_bytes(self) -> int:
        """Size in bytes of one KV cache page (``block_size`` token slots)."""
        elem_bytes = get_dtype_size(self.dtype)
        # For MLA we only store a single latent vector
        copies = 1 if self.use_mla else 2
        kv_bytes = (copies * self.block_size * self.num_kv_heads *
                    self.head_size * elem_bytes)
        if not self.use_sfa:
            return kv_bytes
        # With SFA each token slot carries 128 extra elements.
        # NOTE(review): presumably the SFA indexer cache width - confirm.
        return kv_bytes + 128 * self.block_size * elem_bytes


vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec
@dataclass(frozen=True)
class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec):
    """Full-attention KV cache spec built on the SFA-aware ``AttentionSpec``.

    Registered with ``FullAttentionManager`` and installed over
    ``vllm.v1.kv_cache_interface.FullAttentionSpec`` below.
    """
    # When hybrid allocator is disabled and the model contains both full
    # attention layers and sliding window attention layers, sliding
    # window attention are regarded as full attention in KV cache manager
    # (blocks are allocated for all tokens), while computed as sliding window
    # attention in model runner.
    # In this case, we use FullAttentionSpec and record the sliding window
    # size. Default to None for not using sliding window attention.
    sliding_window: Optional[int] = None
    attention_chunk_size: Optional[int] = None

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Bytes needed to cache ``max_model_len`` tokens on this rank."""
        token_budget = vllm_config.model_config.max_model_len
        dcp_ranks = vllm_config.parallel_config.decode_context_parallel_size
        # Note(hc): each dcp rank only need save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_ranks > 1:
            token_budget = cdiv(token_budget, dcp_ranks)
        pages = cdiv(token_budget, self.block_size)
        return pages * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
        """Collapse a set of window sizes into one value (or None if empty).

        The single element is popped from ``window_sizes``, so the caller's
        set is emptied in that case.

        Raises:
            ValueError: more than one distinct size is present.
        """
        distinct = len(window_sizes)
        if distinct > 1:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size.")
        if distinct == 1:
            return window_sizes.pop()
        return None

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "FullAttentionSpec.")

        # Gather the distinct non-None window / chunk sizes across the group.
        window_sizes = set()
        chunk_sizes = set()
        for spec in specs:
            if spec.sliding_window is not None:
                window_sizes.add(spec.sliding_window)
        for spec in specs:
            if spec.attention_chunk_size is not None:
                chunk_sizes.add(spec.attention_chunk_size)

        # The first spec supplies the shared attributes; every other spec is
        # checked against the result further down.
        template = specs[0]
        merged_spec = cls(
            block_size=template.block_size,
            num_kv_heads=template.num_kv_heads,
            head_size=template.head_size,
            dtype=template.dtype,
            use_mla=template.use_mla,
            use_sfa=template.use_sfa,
            sliding_window=cls.merge_window_sizes(window_sizes),
            attention_chunk_size=cls.merge_window_sizes(chunk_sizes),
        )

        attr_names = [f.name for f in fields(AttentionSpec)]
        for spec in specs:
            for attr in attr_names:
                assert getattr(spec, attr) == getattr(merged_spec, attr), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec.")

        has_window = merged_spec.sliding_window is not None
        has_chunk = merged_spec.attention_chunk_size is not None
        assert (
            has_window + has_chunk <= 1
        ), ("Model with both sliding window layers and chunked local attention "
            "layers is not supported.")
        return merged_spec


spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager})
vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec

View File

@@ -1,147 +0,0 @@
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional, Union
import torch
import torch_npu
import vllm
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from vllm.distributed import (get_tensor_model_parallel_rank,
split_tensor_along_last_dim)
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.linear import RowParallelLinear
import vllm_ascend.envs as envs_ascend
# Module-wide cache of the HCCL communicator name; filled on the first call to
# AscendRowParallelLinear.get_hcomm_info and reused by every later call.
_HCOMM_INFO = None


class AscendRowParallelLinear(RowParallelLinear):
    """
    RowParallelLinear subclass whose forward pass can go through
    ``torch_npu.npu_mm_all_reduce_base`` instead of the quant method.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the AscendRowParallelLinear layer.

        Args:
            *args: Positional arguments forwarded to RowParallelLinear.
            **kwargs: Keyword arguments forwarded to RowParallelLinear.
        """
        # The communicator name is looked up before the parent constructor
        # runs; the transposed weight needs ``self.weight`` and so comes after.
        self.hcomm_info = self.get_hcomm_info(get_tp_group().device_group)
        super().__init__(*args, **kwargs)
        self.weight_t = self.weight.t()

    @staticmethod
    def get_hcomm_info(group: ProcessGroup) -> str:
        """Get the HCCL communication information for the given group.

        The name is computed once and stored in the module-level
        ``_HCOMM_INFO``; later calls return that stored value without
        looking at ``group``.

        Args:
            group (ProcessGroup): The process group to query on first use.

        Returns:
            str: The HCCL communication name.
        """
        global _HCOMM_INFO
        if _HCOMM_INFO is None:
            rank = torch.distributed.get_rank(group)
            if torch.__version__ > "2.0":
                global_rank = torch.distributed.get_global_rank(group, rank)
                backend = group._get_backend(torch.device("npu"))
                _HCOMM_INFO = backend.get_hccl_comm_name(global_rank)
            else:
                _HCOMM_INFO = group.get_hccl_comm_name(rank)
        return _HCOMM_INFO

    def forward(
        self, input_: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        """Forward pass for the AscendRowParallelLinear layer.

        Args:
            input_ (torch.Tensor): the input tensor to the layer.

        Returns:
            The output tensor alone when ``return_bias`` is false; otherwise
            ``(output, bias)`` where ``bias`` is ``self.bias`` if
            ``skip_bias_add`` is set and ``None`` if not.
        """
        local_input = self.calc_input(input_)

        # Matrix multiply.
        assert self.quant_method is not None
        output = self.calc_output(local_input)

        if self.skip_bias_add:
            output_bias = self.bias
        else:
            output_bias = None

        if self.return_bias:
            return output, output_bias
        return output

    def calc_input(self, input_: torch.Tensor) -> torch.Tensor:
        """Calculate the input tensor for parallel processing.

        Args:
            input_ (torch.Tensor): the input tensor to be processed.

        Returns:
            torch.Tensor: ``input_`` unchanged when ``input_is_parallel`` is
            set; otherwise this rank's contiguous slice of ``input_`` split
            along the last dimension into ``tp_size`` partitions.
        """
        if not self.input_is_parallel:
            rank = get_tensor_model_parallel_rank()
            shards = split_tensor_along_last_dim(input_,
                                                 num_partitions=self.tp_size)
            return shards[rank].contiguous()
        return input_

    def calc_output(self, input_parallel: torch.Tensor) -> torch.Tensor:
        """Calculate the output tensor of forward by considering
        fusing communication and computation.

        Args:
            input_parallel (torch.Tensor): this rank's input partition.

        Returns:
            torch.Tensor: result of the fused matmul + all-reduce kernel when
            ``reduce_results`` is set and ``tp_size > 1``; otherwise the
            result of ``self.quant_method.apply``.
        """
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        if self.tp_rank > 0 or self.skip_bias_add:
            fused_bias = None
        else:
            fused_bias = self.bias

        if self.reduce_results and self.tp_size > 1:
            return torch_npu.npu_mm_all_reduce_base(input_parallel,
                                                    self.weight_t,
                                                    self.hcomm_info,
                                                    bias=fused_bias)
        return self.quant_method.apply(self, input_parallel, bias=fused_bias)


if envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE:
    logger.info("AscendRowParallelLinear: Matmul all-reduce is enabled. ")
    vllm.model_executor.layers.linear.RowParallelLinear = AscendRowParallelLinear

View File

@@ -1,29 +0,0 @@
from typing import Optional
import vllm
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.lora.layers import VocabParallelEmbeddingWithLoRA
from vllm.lora.utils import _all_lora_classes
from vllm_ascend.ops.vocab_parallel_embedding import \
AscendVocabParallelEmbedding
class AscendVocabParallelEmbeddingWithLoRA(VocabParallelEmbeddingWithLoRA):
    """LoRA wrapper that claims exactly ``AscendVocabParallelEmbedding`` layers."""

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Return True only when ``source_layer`` is exactly an
        ``AscendVocabParallelEmbedding`` (subclasses do not match).

        ``lora_config``, ``packed_modules_list`` and ``model_config`` are
        accepted but not consulted.
        """
        layer_type = type(source_layer)
        return layer_type is AscendVocabParallelEmbedding


# Patch for lora register_model issue after overriding VocabParallelEmbedding class (#2515)
_all_lora_classes.add(AscendVocabParallelEmbeddingWithLoRA)
vllm.lora.utils._all_lora_classes = _all_lora_classes

View File

@@ -0,0 +1,16 @@
import vllm.model_executor.layers.fla.ops.chunk
import vllm.model_executor.layers.fla.ops.fused_recurrent
import vllm.model_executor.layers.fla.ops.layernorm_guard
import vllm.model_executor.layers.mamba.ops.causal_conv1d
from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn,
causal_conv1d_update_npu)
from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule
from vllm_ascend.ops.sigmoid_gating import \
fused_recurrent_gated_delta_rule_fwd_kernel
# Monkey-patch vLLM's mamba / FLA op modules so that the attributes below
# resolve to the implementations imported from vllm_ascend.ops.
# Causal conv1d: update step and full function.
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
# Fused recurrent gated delta rule forward kernel.
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
# Layer norm function used by the layernorm_guard module.
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
# Chunked gated delta rule is replaced by the torch implementation.
vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule

View File

@@ -0,0 +1,44 @@
import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import GiB_bytes
from vllm_ascend.utils import vllm_version_is
logger = init_logger(__name__)
def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int,
                   output_partition_sizes: list[int], input_size: int,
                   output_size: int, params_dtype: torch.dtype,
                   **extra_weight_attrs):
    """Allocate the unquantized linear weight and register it on ``layer``.

    The parameter has shape
    ``(sum(output_partition_sizes), input_size_per_partition)``, dtype
    ``params_dtype`` and ``requires_grad=False``. It is registered as
    ``"weight"`` and tagged with ``input_dim=1`` / ``output_dim=0`` plus
    everything in ``extra_weight_attrs``. ``input_size`` and ``output_size``
    are accepted but not used here.

    Raises:
        RuntimeError: allocation raised ``torch.cuda.OutOfMemoryError``.
    """
    # This method creates unquantized linear weights.
    # The weights are not quantized, and they are not sharded.
    # The amount of memory allocated for the weights is
    # sum(output_partition_sizes) * input_size_per_partition.
    try:
        storage = torch.empty(sum(output_partition_sizes),
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)
    except torch.cuda.OutOfMemoryError as oom:
        logger.error("Failed to create unquantized linear weights: %s", oom)
        if torch.cuda.is_available():
            logger.debug("CUDA device: %s", torch.cuda.current_device())
            logger.debug("Allocated: %.2f GiB",
                         torch.cuda.memory_allocated() / GiB_bytes)
            logger.debug("Reserved: %.2f GiB",
                         torch.cuda.memory_reserved() / GiB_bytes)
        raise RuntimeError(
            "Failed to create unquantized linear weights. "
            "This may be caused by insufficient memory to allocate "
            "the weight.") from oom

    set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
    layer.register_parameter("weight", weight)
    set_weight_attrs(weight, extra_weight_attrs)


# Only vLLM versions other than 0.10.2 get the replacement installed.
if not vllm_version_is("0.10.2"):
    from vllm.model_executor.layers.linear import UnquantizedLinearMethod
    UnquantizedLinearMethod.create_weights = create_weights

View File

@@ -16,6 +16,7 @@
#
import gc
import os
from datetime import timedelta
from typing import TYPE_CHECKING, Optional, Tuple
@@ -31,7 +32,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
delete_torchair_cache_file)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p,
update_aclgraph_sizes)
update_aclgraph_sizes, vllm_version_is)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -128,11 +129,43 @@ class NPUPlatform(Platform):
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
ascend_scheduler_config = ascend_config.ascend_scheduler_config
if vllm_version_is("0.10.2"):
structured_outputs_config = vllm_config.decoding_config
else:
structured_outputs_config = vllm_config.structured_outputs_config
if (model_config is not None and not model_config.use_mla
and not scheduler_config.async_scheduling):
logger.info(
"Non-MLA LLMs forcibly disable the chunked prefill feature,"
"as the performance of operators supporting this feature "
"functionality is currently suboptimal.")
if not model_config.is_multimodal_model and \
structured_outputs_config.backend == "auto" and \
not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \
not scheduler_config.send_delta_data and \
scheduler_config.policy == "fcfs":
ascend_scheduler_config.enabled = True
chunked_prefill_enabled_in_ascend_scheduler = getattr(
ascend_scheduler_config, "enable_chunked_prefill", False)
if chunked_prefill_enabled_in_ascend_scheduler:
logger.warning(
"Chunked prefill feature is enabled in ascend_scheduler,"
"but note that the operator supporting this feature "
"would lead to performance degradation.")
# In this situation, max_num_batched_tokens would have been rewritten.
# So we must make sure max_num_batched_tokens is not smaller than max_model_len.
if (scheduler_config.max_num_batched_tokens
< scheduler_config.max_model_len
and not chunked_prefill_enabled_in_ascend_scheduler):
scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len
kv_cache_dtype = vllm_config.additional_config.get(
"kv_cache_dtype", None)
if kv_cache_dtype is not None:
vllm_config.cache_config.cache_dtype = kv_cache_dtype
if model_config is None:
logger.warning("Model config is missing. This may indicate "
"that we are running a test case")
@@ -148,23 +181,13 @@ class NPUPlatform(Platform):
compilation_config.cudagraph_num_of_warmups = 1
# TODO: make vllm support oot platform to set `compilation_config.cudagraph_mode`
# if cudagraph_mode is not explicitly set by users, set default value
if compilation_config.level == CompilationLevel.PIECEWISE:
compilation_config.cudagraph_mode = \
CUDAGraphMode.PIECEWISE
elif compilation_config.level not in [
if compilation_config.level not in [
CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
]:
logger.warning(
"NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
compilation_config.level)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
else:
logger.warning(
"compilation_config.level = CompilationLevel.NO_COMPILATION is set, Setting CUDAGraphMode to NONE"
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set CUDAGraphMode to NONE when torchair is enabled, no matter what compilation_config.level is.
if ascend_config.torchair_graph_config.enabled:
@@ -185,18 +208,22 @@ class NPUPlatform(Platform):
"and use_cached_kv_cache_bytes in torchair_graph_config.")
delete_torchair_cache_file()
if parallel_config.distributed_executor_backend == "ray":
logger.warning(
"Ray distributed executor backend is not compatible with ACL Graph mode "
"right now. Setting CUDAGraphMode to NONE")
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# set cudagraph sizes before extending `compilation_config.splitting_ops`
vllm_config._set_cudagraph_sizes()
# TODO: Full graph is fully supported later, and the default value will be set to full graph.
if not vllm_version_is("0.10.2"):
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.level = CompilationLevel.NO_COMPILATION
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
# TODO: Currently MLA does not support FULL_DECODE_ONLY, remove the second condition
# after MLA being supported
elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE or (
compilation_config.cudagraph_mode
== CUDAGraphMode.FULL_DECODE_ONLY and model_config is not None
and model_config.use_mla):
logger.info(
"PIECEWISE compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
@@ -204,9 +231,28 @@ class NPUPlatform(Platform):
"When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
compilation_config.set_splitting_ops_for_v1()
compilation_config.use_inductor = False
compilation_config.splitting_ops.extend(
["vllm.unified_ascend_attention_with_output"])
compilation_config.splitting_ops.extend([
"vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
])
update_aclgraph_sizes(vllm_config)
elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
logger.info(
"FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
"using only ACL Graph mode")
compilation_config.use_inductor = False
warning_message = """\033[91m
**********************************************************************************
* WARNING: You have enabled the *full graph* feature.
* This is an early experimental stage and may involve various unknown issues.
* A known problem is that capturing too many batch sizes can lead to OOM
* (Out of Memory) errors or inference hangs. If you encounter such issues,
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
* batch size for graph capture.
* For more details, please refer to:
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
**********************************************************************************\033[0m
"""
logger.warning(warning_message)
else:
logger.info(
"%s cudagraph_mode is not support on NPU. falling back to NONE",
@@ -215,7 +261,9 @@ class NPUPlatform(Platform):
compilation_config.level = CompilationLevel.NO_COMPILATION
if parallel_config and parallel_config.worker_cls == "auto":
if ascend_config.torchair_graph_config.enabled:
# TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
if ascend_config.torchair_graph_config.enabled or ascend_config.enable_shared_expert_dp:
parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
else:
parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"
@@ -223,6 +271,7 @@ class NPUPlatform(Platform):
if cache_config:
if cache_config.block_size is None:
cache_config.block_size = 128
if cache_config.enable_prefix_caching and cache_config.block_size != 128:
logger.warning(
"If prefix caching is enabled, block size must be set to 128."
@@ -242,12 +291,6 @@ class NPUPlatform(Platform):
ascend_config.ascend_scheduler_config)
vllm_config.scheduler_config = ascend_scheduler_config
if compilation_config.pass_config.enable_sequence_parallelism:
if not parallel_config.enable_expert_parallel or vllm_config.model_config.hf_config.model_type != "qwen3_moe":
raise NotImplementedError(
"For better performance in Qwen3 MoE, SP only works exclusively with MC2, AllToAll, and AllToAllV."
)
@classmethod
def get_attn_backend_cls(cls,
selected_backend,
@@ -257,27 +300,40 @@ class NPUPlatform(Platform):
block_size,
use_v1,
use_mla,
use_sfa,
has_sink=False):
if not use_v1:
raise ValueError("vLLM Ascend does not support V0 engine.")
use_torchair = get_ascend_config().torchair_graph_config.enabled
ascend_config = get_ascend_config()
if use_mla and ascend_config.enable_shared_expert_dp:
if use_mla and not use_sfa:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and use_sfa:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
use_torchair = ascend_config.torchair_graph_config.enabled
# choose attention backend based on use_mla and use_torchair
backend_map = {
(True, True):
(True, False, True):
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
(True, False):
(True, False, False):
"vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, True):
(False, False, True):
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
(False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend"
(False, False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(True, True, False):
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
(True, True, True):
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
}
return backend_map[(use_mla, use_torchair)]
return backend_map[(use_mla, use_sfa, use_torchair)]
@classmethod
def get_punica_wrapper(cls) -> str:
return "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"
@classmethod
def get_current_memory_usage(cls,
@@ -343,3 +399,11 @@ class NPUPlatform(Platform):
pg._register_backend(device, backend_type, backend_class)
return pg
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Report hybrid KV cache support; always True for this platform."""
    supported = True
    return supported
@classmethod
def support_static_graph_mode(cls) -> bool:
    """Report static graph mode support; always True for this platform."""
    supported = True
    return supported

View File

@@ -1,184 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Tuple, Union
import torch
import torch_npu
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
# func refers to vocabParallelEmbedding.__init__
def wrapper_vocab_parallel_embedding_init(func):
    """Wrap ``func`` so the instance also records ``params_dtype``.

    The returned ``init`` forwards all seven arguments positionally to
    ``func`` and then sets ``self.params_dtype``, using
    ``torch.get_default_dtype()`` when ``params_dtype`` was passed as None.
    """

    def init(
        self,
        num_embeddings: int,
        embedding_dim: int,
        params_dtype: Optional[torch.dtype] = None,
        org_num_embeddings: Optional[int] = None,
        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # The wrapped constructor receives params_dtype as given, including None.
        forwarded = (num_embeddings, embedding_dim, params_dtype,
                     org_num_embeddings, padding_size, quant_config, prefix)
        func(self, *forwarded)

        # TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

    return init
# func refers to RMSNorm.__init__
def wrapper_rmsnorm_init(func):
    """Wrap ``func`` so the instance also gets ``ignore_anti`` and ``bias``.

    After calling ``func(self, hidden_size, **extra_args)`` the returned
    ``init`` sets ``self.ignore_anti = True`` and attaches a frozen
    (``requires_grad=False``) all-zero bias of length ``hidden_size``.
    """

    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        self.ignore_anti = True
        zero_bias = torch.zeros(hidden_size)
        self.bias = torch.nn.Parameter(zero_bias, requires_grad=False)

    return init
# func refers to RMSNorm.forward_oot
def wrapper_rmsnorm_forward_oot(func):
    """Wrap ``func`` with a bias-aware / quantizing RMSNorm forward.

    When ``self.ignore_anti`` is true the wrapped ``func`` runs and
    ``self.bias`` is added in place to its normalized output. Otherwise
    ``torch_npu._npu_quant_rms_norm`` is called with the layer's weight,
    bias, input scale/offset and epsilon; a given ``residual`` is first
    updated in place with ``x`` and returned alongside the output.
    """

    def _rmsnorm_forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        with_residual = residual is not None

        if self.ignore_anti:
            # Plain path: delegate, then add the bias in place.
            if with_residual:
                x, residual = func(self, x, residual)
                return x.add_(self.bias), residual
            return func(self, x).add_(self.bias)

        # Quantizing path: the fused kernel normalizes the residual sum when
        # a residual is supplied, and x itself when it is not.
        if with_residual:
            residual += x
            source = residual
        else:
            source = x
        out = torch_npu._npu_quant_rms_norm(
            source,
            self.weight,
            self.bias,
            self.input_scale,
            self.input_offset,
            self.variance_epsilon,
        )
        if with_residual:
            return out, residual
        return out

    return _rmsnorm_forward_oot
# Per-model-class lookup used by wrapper_load_model, keyed by the inner
# model's class name. For each of the "attn" and "mlp" entries:
#   layer_attr       - attribute on a decoder layer holding the sub-module
#   proj_attr        - attribute on that sub-module holding the projection
#                      (wrapper_load_model also accepts a callable here)
#   norm_attr        - attribute on the decoder layer holding the norm
#   unquantized_type - quant-method class that marks the projection as
#                      unquantized
MODEL_LAYER_MAPPING = {
"LlamaModel": {
"attn": {
"layer_attr": "self_attn",
"proj_attr": "qkv_proj",
"norm_attr": "input_layernorm",
"unquantized_type": UnquantizedLinearMethod,
},
"mlp": {
"layer_attr": "mlp",
"proj_attr": "gate_up_proj",
"norm_attr": "post_attention_layernorm",
"unquantized_type": UnquantizedLinearMethod,
},
},
}
def wrapper_load_model(func):
    """Wrap a model-loading method with RMSNorm post-processing.

    After ``func(self)`` runs, the returned method looks up the inner model's
    class name in ``MODEL_LAYER_MAPPING``. If there is no entry it logs a
    message and stops. Otherwise, for every decoder layer it sets each mapped
    norm's ``ignore_anti`` from the matching projection's quant method and,
    for quantized projections, copies ``input_scale`` / ``input_offset`` onto
    the norm as frozen parameters. The final ``norm`` of the model, when it
    is an ``RMSNorm``, always gets ``ignore_anti = True``.
    """

    def postprocess_loading(self) -> None:
        func(self)

        inner_model = self.model.model
        model_type = inner_model.__class__.__name__
        mapping = MODEL_LAYER_MAPPING.get(model_type)
        if not mapping:
            logger.info(
                f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping."
            )
            return

        def link_norm_to_proj(module_cfg, layer_obj, idx):
            # Missing config, sub-module, projection or norm: nothing to do.
            if module_cfg is None:
                return
            module_obj = getattr(layer_obj, module_cfg["layer_attr"], None)
            if module_obj is None:
                return

            proj_attr = module_cfg["proj_attr"]
            proj = (proj_attr(module_obj, idx) if callable(proj_attr) else
                    getattr(module_obj, proj_attr, None))
            norm = getattr(layer_obj, module_cfg["norm_attr"], None)
            if proj is None or norm is None:
                return

            # An unquantized projection leaves the norm on its plain path.
            norm.ignore_anti = isinstance(proj.quant_method,
                                          module_cfg["unquantized_type"])
            if norm.ignore_anti:
                return

            # Quantized projection: hand its quantization parameters to the
            # norm as frozen copies.
            for param_name in ("input_scale", "input_offset"):
                if not hasattr(proj, param_name):
                    continue
                source_param = getattr(proj, param_name)
                norm.register_parameter(
                    param_name,
                    torch.nn.Parameter(source_param.clone(),
                                       requires_grad=False))

        for idx, layer in enumerate(inner_model.layers):
            for section in ("attn", "mlp"):
                link_norm_to_proj(mapping.get(section), layer, idx)

        if isinstance(inner_model.norm, RMSNorm):
            inner_model.norm.ignore_anti = True

    return postprocess_loading

Some files were not shown because too many files have changed in this diff Show More