v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

View File

@@ -0,0 +1,56 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
import vllm_ascend.ops.common_fused_moe # noqa
import vllm_ascend.ops.fused_moe # noqa
import vllm_ascend.ops.layernorm # noqa
import vllm_ascend.ops.vocab_parallel_embedding # noqa
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.rotary_embedding import (
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding)
class dummyFusionOp:
default = None
def __init__(self, name=""):
self.name = name
def register_dummy_fusion_op() -> None:
torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
name="static_scaled_fp8_quant")
torch.ops._C.dynamic_scaled_fp8_quant = dummyFusionOp(
name="dynamic_scaled_fp8_quant")
torch.ops._C.dynamic_per_token_scaled_fp8_quant = dummyFusionOp(
name="dynamic_per_token_scaled_fp8_quant")
torch.ops._C.rms_norm_static_fp8_quant = dummyFusionOp(
name="rms_norm_static_fp8_quant")
torch.ops._C.fused_add_rms_norm_static_fp8_quant = dummyFusionOp(
name="fused_add_rms_norm_static_fp8_quant")
torch.ops._C.rms_norm_dynamic_per_token_quant = dummyFusionOp(
name="rms_norm_dynamic_per_token_quant")
__all__ = [
"AscendQuickGELU", "AscendSiluAndMul", "AscendRotaryEmbedding",
"AscendDeepseekScalingRotaryEmbedding"
]

View File

@@ -0,0 +1,42 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
class AscendQuickGELU(QuickGELU):
def forward_oot(self, x: torch.tensor) -> torch.Tensor:
import torch_npu
out = torch_npu.npu_fast_gelu(x)
return out
class AscendSiluAndMul(SiluAndMul):
def forward_oot(self, x: torch.Tensor) -> torch.Tensor:
import torch_npu
from vllm_ascend.utils import is_310p
if is_310p():
out = torch_npu.npu_swiglu(x.to(torch.float32)).to(torch.float16)
else:
out = torch_npu.npu_swiglu(x)
return out

View File

@@ -0,0 +1,309 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear
# Implementation of vanilla chunked prefill, should be removed after the kernel is ready for
# all the corner case
def vanilla_chunked_prefill(
output: torch.Tensor,
query: torch.Tensor, # (num_tokens, heads, head_size)
key_cache: torch.Tensor, # (num_blocks, block_size, kv_heads, head_size)
value_cache: torch.
Tensor, # (num_blocks, block_size, kv_heads, head_size,)
block_tables: torch.Tensor, # (num_seqs, max_num_blocks_per_seq)
cu_seqlen_q: torch.Tensor, # (num_seqs + 1,)
cu_seqlen_k: torch.Tensor, # (num_seqs + 1,)
max_seqlen_q: int,
max_seqlen_k: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
causal: bool = True,
) -> torch.Tensor:
num_query_heads = query.shape[1]
head_dim = value_cache.shape[3]
num_kv_heads = value_cache.shape[2]
block_size = value_cache.shape[1]
num_batch = cu_seqlen_q.shape[0] - 1
max_num_blocks_per_seq = block_tables.shape[1]
key = key_cache[block_tables].view(num_batch,
max_num_blocks_per_seq * block_size,
num_kv_heads, head_dim)
value = value_cache[block_tables].view(num_batch,
max_num_blocks_per_seq * block_size,
num_kv_heads, head_dim)
key = key[:, :max_seqlen_k, :, :]
value = value[:, :max_seqlen_k, :, :]
seqlen_k = cu_seqlen_k[1:] - cu_seqlen_k[:-1]
seqlen_q = cu_seqlen_q[1:] - cu_seqlen_q[:-1]
seqlen_q = seqlen_q.view(-1, 1)
seqlen_k = seqlen_k.view(-1, 1)
seqlen_diff = seqlen_k - seqlen_q
q_idx_mask = (torch.arange(0, max_seqlen_q,
device="npu").view(1, -1).repeat(num_batch, 1))
k_idx_mask = (torch.arange(0, max_seqlen_k,
device="npu").view(1, -1).repeat(num_batch, 1))
q_mask = q_idx_mask < seqlen_q
k_mask = k_idx_mask < seqlen_k
# calculate idx for causal mask of query [batch, max_seqlen_q]
causal_mask_idx = (q_idx_mask + seqlen_diff)[q_mask]
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
tril_mask = torch.tril(torch.ones(max_seqlen_k, max_seqlen_k,
device="npu"))
tril_mask[tril_mask == 0] = float("-inf")
tril_mask[tril_mask == 1] = 0
causal_mask = tril_mask[causal_mask_idx]
causal_mask_padding = torch.empty([num_batch, max_seqlen_q, max_seqlen_k],
device="npu").fill_(float("-inf"))
causal_mask_padding[q_mask] = causal_mask
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
causal_mask_padding = causal_mask_padding.unsqueeze(1)
pad_q = torch.zeros(
[num_batch, max_seqlen_q, num_query_heads, head_dim],
device="npu",
dtype=query.dtype,
)
pad_k = torch.zeros(
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
device="npu",
dtype=key.dtype,
)
pad_v = torch.zeros(
[num_batch, max_seqlen_k, num_kv_heads, head_dim],
device="npu",
dtype=value.dtype,
)
pad_q[q_mask] = query
pad_k[k_mask] = key[k_mask]
pad_v[k_mask] = value[k_mask]
if num_query_heads > num_kv_heads:
pad_k = pad_k.view(
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
pad_k = pad_k.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
[num_batch, max_seqlen_k, num_query_heads, head_dim])
pad_v = pad_v.view(
[num_batch, max_seqlen_k, num_kv_heads, 1, head_dim])
pad_v = pad_v.repeat(1, 1, 1, num_query_heads // num_kv_heads, 1).view(
[num_batch, max_seqlen_k, num_query_heads, head_dim])
# permute to [b, h, n, k]
pad_q = pad_q.permute(0, 2, 1, 3)
pad_k = pad_k.permute(0, 2, 1, 3)
pad_v = pad_v.permute(0, 2, 1, 3)
attn_mask = torch.empty([num_batch, 1, 1, max_seqlen_k],
device="npu").fill_(float("-inf"))
attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)
# [b, h, f, t]
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
attn_weights *= scale
attn_mask = attn_mask.float()
attn_weights = attn_weights + attn_mask
if causal:
attn_weights = attn_weights + causal_mask_padding
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = (attn_output[q_mask].view([-1, num_query_heads,
head_dim]).to(output.dtype))
output.copy_(attn_output)
return attn_output
def vanilla_chunked_prefill_mla(
output: torch.Tensor, # (num_tokens, num_heads, v_head_dim)
query: torch.Tensor, # (num_tokens, num_heads, nope_dim + rope_dim)
kv_cache: Tuple[
torch.Tensor], # [nope, rope] (num_blocks, block_size, latent_kv)
block_tables: torch.Tensor, # (batch_size, max_num_blocks_per_seq)
query_lens: torch.Tensor, # (batch_size)
context_lens: torch.Tensor, # (batch_size)
kv_b_proj: ColumnParallelLinear, # ()
max_query_len: int,
max_context_len: int,
nope_dim: int,
rope_dim: int,
v_head_dim: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
causal: bool = True) -> None:
batch_size = block_tables.size(0)
assert len(kv_cache) > 1
assert query_lens.size(0) == batch_size
num_heads = query.size(1)
nope_cache = kv_cache[0]
rope_cache = kv_cache[1]
block_size = nope_cache.size(1)
latent_kv_dim = nope_cache.size(-1)
max_num_blocks_per_seq = block_tables.size(1)
batch_size = query_lens.size(0)
nope_cache = nope_cache.squeeze()
# select kv_c out as [batch_size, max_context_len, latent_kv + rope_dim] and get kv_c and k_pe
# cached_kv_c: [batch_size, max_context_len, latent_kv]
# cached_k_pe: [batch_size, max_context_len, rope_dim]
cache_kv_c = nope_cache[block_tables].view(
batch_size, max_num_blocks_per_seq * block_size,
latent_kv_dim)[:, :max_context_len, :]
cache_k_pe = rope_cache[block_tables].view(
batch_size, max_num_blocks_per_seq * block_size,
rope_dim)[:, :max_context_len, :]
# get k_rope and v
# k_nope: [batch_size, max_context_len, num_heads, nope_dim]
# value: [batch_size, max_context_len, num_heads, v_head_dim]
k_nope, value = kv_b_proj(cache_kv_c)[0].view(
batch_size, max_context_len, num_heads,
nope_dim + v_head_dim).split([nope_dim, v_head_dim], dim=-1)
# key: [batch_size, max_context_len, num_hads, rope_dim + nope_dim]
key = torch.cat(
[k_nope, cache_k_pe.unsqueeze(2).expand(-1, -1, num_heads, -1)],
dim=-1)
context_lens = context_lens.view(-1, 1).to("npu")
query_lens = query_lens.view(-1, 1).to("npu")
seq_diff = context_lens - query_lens
q_idx_mask = (torch.arange(0, max_query_len,
device="npu").view(1, -1).repeat(batch_size, 1))
kv_c_idx_mask = (torch.arange(0, max_context_len,
device="npu").view(1,
-1).repeat(batch_size, 1))
kv_c_mask = kv_c_idx_mask < context_lens
q_mask = q_idx_mask < query_lens
# calculate idx for causal mask of query [batch, max_seqlen_q]
causal_mask_idx = (q_idx_mask + seq_diff)[q_mask]
# generate causal mask [batch, max_seqlen_q, max_seqlen_k]
tril_mask = torch.tril(
torch.ones(max_context_len, max_context_len, device="npu"))
tril_mask[tril_mask == 0] = float("-inf")
tril_mask[tril_mask == 1] = 0
causal_mask = tril_mask[causal_mask_idx]
causal_mask_padding = torch.empty(
[batch_size, max_query_len, max_context_len],
device="npu").fill_(float("-inf"))
causal_mask_padding[q_mask] = causal_mask
# to [batch, num_heads, max_seqlen_q, max_seqlen_k]
causal_mask_padding = causal_mask_padding.unsqueeze(1)
pad_q = torch.zeros(
[batch_size, max_query_len, num_heads, rope_dim + nope_dim],
device="npu",
dtype=query.dtype,
)
pad_k = torch.zeros(
[batch_size, max_context_len, num_heads, rope_dim + nope_dim],
device="npu",
dtype=key.dtype,
)
pad_v = torch.zeros(
[batch_size, max_context_len, num_heads, v_head_dim],
device="npu",
dtype=value.dtype,
)
num_query = torch.sum(q_mask).item()
num_add_query = num_query - query.size(0)
# mtp will come in
if num_add_query > 0:
add_query_size = query.size()
add_query_size = list(add_query_size)
add_query_size[0] = num_add_query
pad_tensor = torch.zeros(add_query_size,
dtype=query.dtype,
device=query.device)
query = torch.cat([query, pad_tensor], dim=0)
pad_q[q_mask] = query
pad_k[kv_c_mask] = key[kv_c_mask]
pad_v[kv_c_mask] = value[kv_c_mask]
pad_q = pad_q.permute(0, 2, 1, 3)
pad_k = pad_k.permute(0, 2, 1, 3)
pad_v = pad_v.permute(0, 2, 1, 3)
attn_mask = torch.empty([batch_size, 1, 1, max_context_len],
device="npu").fill_(float("-inf"))
attn_mask[:, :, :, :max_context_len].masked_fill_(
kv_c_mask[:, None, None, :], 0)
# [b, h, f, t]
attn_weights = torch.einsum("bhqd,bhkd->bhqk", pad_q, pad_k)
attn_weights *= scale
attn_mask = attn_mask.float()
attn_weights = attn_weights + attn_mask
if causal:
attn_weights = attn_weights + causal_mask_padding
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.einsum("bhqk,bhkd->bhqd", attn_weights, pad_v.float())
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = (attn_output[q_mask].view([-1, num_heads,
v_head_dim]).to(output.dtype))
attn_output = attn_output.view_as(output)
output.copy_(attn_output)
return attn_output
def vanilla_decode_mla(
query: torch.Tensor, # [num_tokens, num_heads, latent_dim + rope_dim]
key_cache: torch.
Tensor, # [num_blocks, block_size, num_kv_heads, latent_dim + rope_dim]
num_kv_heads: int,
num_heads: int,
scale: float,
block_table: torch.Tensor, # [batch_size, max_block_size]
context_lens: List[int],
mla_vhead_size: int,
rope_dim: int,
output: torch.Tensor):
batch_size = block_table.size()[0]
max_block_size = block_table.size()[1]
reduce_dim = key_cache.size()[-1]
block_size = key_cache.size()[1]
latent_dim = reduce_dim - rope_dim
kv_c_and_pe = key_cache[block_table].view(
[batch_size, max_block_size * block_size, num_kv_heads, reduce_dim])
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, device="npu").view(batch_size, 1)
# [batch_size, max_context_len, num_kv_heads, latent_dim + rope_dim]
# since the kv head is 1 in deepseek, we use expand here for perf
kv_c_and_pe = kv_c_and_pe[:, :max_context_len, :, :].expand(
-1, -1, num_heads, 1)
kv_c = kv_c_and_pe[..., :latent_dim]
kv_idx_mask = (torch.arange(0, max_context_len,
device="npu").view(1,
-1).repeat(batch_size, 1))
# [batch_size, max_context_len]
kv_idx_mask = kv_idx_mask < context_lens
query = query.unsqueeze(1)
attn_weights = torch.einsum("bqhd,bkhd->bhqk", query, kv_c_and_pe)
attn_weights *= scale
attn_weights = attn_weights + kv_idx_mask[:, -1, -1, :].float()
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.einsum("bhqk,bkhd->bqhd", attn_weights,
kv_c.float()).view(-1, num_heads, latent_dim)
output.copy_(attn_output)
return output

View File

@@ -0,0 +1,62 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import torch
import torch.distributed
import torch.distributed as dist
import torch_npu
COMM_STREAM = None
def async_all_to_all(input_,
output_split_sizes,
input_split_sizes,
group,
event=None):
if output_split_sizes is None:
# Equal split (all2all)
a2a_out = torch.empty_like(input_)
else:
# Unequal split (all2all-v)
a2a_out = input_.new_empty(
size=[sum(output_split_sizes)] + list(input_.size()[1:]),
dtype=input_.dtype,
device=torch.npu.current_device(),
)
if event:
# multi stream wait event
global COMM_STREAM
if COMM_STREAM is None:
COMM_STREAM = torch_npu.npu.Stream(
device=torch.npu.current_device())
with torch_npu.npu.stream(COMM_STREAM):
event.wait()
handle = dist.all_to_all_single(
a2a_out,
input_.contiguous(),
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=True)
else:
handle = dist.all_to_all_single(a2a_out,
input_.contiguous(),
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=True)
return input_, a2a_out, handle

View File

@@ -0,0 +1,531 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Optional
import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl,
MC2CommImpl)
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
# For TorchAir graph
is_torchair: bool = False,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
) -> torch.Tensor:
# Check constraints
assert hidden_states.shape[1] == w1.shape[1], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
if (use_int8_w8a8 or use_int4_w4a8):
assert w1_scale is not None and w2_scale is not None, \
"INT8 quantization requires weight scales."
w1_scale = w1_scale.to(torch.float32)
down_scale = [w2_scale]
down_output_dtype = w2_scale.dtype
else:
down_scale = None
down_output_dtype = None
moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
num_experts = w1.shape[0]
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
hidden_states, topk_ids, topk_weights, expert_map, num_experts,
use_int8_w8a8 or use_int4_w4a8)
gate_up_output = torch_npu.npu_grouped_matmul(
x=[permuted_hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=expert_tokens,
output_dtype=torch.int32 if use_int8_w8a8 else None,
)[0]
if (use_int8_w8a8 or use_int4_w4a8):
activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
x=gate_up_output,
weight_scale=w1_scale,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_tokens,
activate_left=True,
quant_mode=1,
)
activated_output_scale = [activated_output_scale]
else:
activated_output = torch_npu.npu_swiglu(gate_up_output)
activated_output_scale = None
down_output = torch_npu.npu_grouped_matmul(
x=[activated_output],
weight=[w2],
scale=down_scale,
per_token_scale=activated_output_scale,
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=expert_tokens,
output_dtype=down_output_dtype,
)[0]
moe_comm_method.unpermute(down_output, hidden_states)
return hidden_states
def fused_experts_moge(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
moe_parallel_config: FusedMoEParallelConfig,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
top_k: Number of experts to select.
expert_map: Expert mapping of shape (num_experts,).
Returns:
hidden_states: Hidden states after routing.
"""
ep_size = moe_parallel_config.ep_size
local_num_experts = global_num_experts // ep_size
local_num_group = top_k // ep_size
bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, sorted_topk_ids // local_num_group)
experts_id = torch.arange(0,
local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
gate_up_out = torch_npu.npu_grouped_matmul(
x=[sorted_hidden_states],
weight=[w1],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out *= topk_scales
down_out_list = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)
return final_hidden_states
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
# NOTE: Currently, this self.use_aclgraph is only used in
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
# Once torch.randint_like is supported or removed, this flag can be removed.
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
self.use_aclgraph = (vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
def forward_oot_v01011(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids, _ = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=1.0,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
moe_parallel_config=self.moe.moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
def forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids, _ = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
moe_parallel_config=self.moe.moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
class AscendFusedMoE(FusedMoE):
def __init__(
self,
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype=None,
reduce_results=False,
renormalize=True,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
quant_config=None,
tp_size=None,
ep_size=None,
dp_size=None,
prefix="",
custom_routing_function=None,
scoring_func="softmax",
routed_scaling_fator: float = 1.0,
e_score_correction_bias=None,
apply_router_weight_on_input=False,
activation="silu",
enable_eplb=False,
num_redundant_experts=0,
has_bias=False,
):
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
reduce_results,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
ep_size,
dp_size,
prefix,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
num_redundant_experts,
has_bias,
)
else:
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
reduce_results,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
ep_size,
dp_size,
prefix,
custom_routing_function,
scoring_func,
routed_scaling_fator,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
num_redundant_experts,
has_bias,
)
setup_token_dispatchers(self.moe_config.ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_local_experts=self.local_num_experts)
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
# TODO: Can we refactor this logic to model_runner?
# TODO: Adjusted logic to differentiate between A2 and A3, we check ep_size here since mc2 only support ep_size >= 16 on A3 now
if self.moe_config.ep_size < 16:
moe_comm_method_name = "allgathercommimpl"
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
reduce_results=self.reduce_results)
return final_hidden_states
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011
else:
UnquantizedFusedMoEMethod.forward_oot = forward_oot

View File

@@ -0,0 +1,99 @@
import json
import random
from typing import Dict, List
import torch
class ExpertLoadBalancer(object):
def __init__(self, expert_map_path, global_expert_num):
self.expert_map_path = expert_map_path
self.global_expert_num = global_expert_num
self.expert_map_tensor, self.layers_num, self.ranks_num = (
self._expert_file_to_tensor())
def _expert_file_to_tensor(self):
with open(self.expert_map_path, "r") as f:
data = json.load(f)
layers_num = data["moe_layer_count"]
gpus_num = data["layer_list"][0]["device_count"]
tensor_data = []
for layer in data["layer_list"]:
device_data = []
for device in layer["device_list"]:
device_data.append(device["device_expert"])
tensor_data.append(device_data)
expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
return expert_map_tensor, layers_num, gpus_num
def generate_index_dicts(self, tensor_2d):
dict_list = []
current_idx = 0
for row in tensor_2d:
value_to_index = {}
for i in range(row.size(0)):
value = row[i].item()
value_to_index[value] = current_idx + i
dict_list.append(value_to_index)
current_idx += row.size(0)
return dict_list
def generate_expert_placement_map(self):
expert_placement_map = torch.full(
(self.layers_num, self.ranks_num, self.global_expert_num),
-1,
dtype=torch.int32,
)
for layer_id in range(self.layers_num):
for gpu_id in range(self.ranks_num):
e_ids = self.expert_map_tensor[layer_id, gpu_id]
expert_placement_map[layer_id, gpu_id,
e_ids] = torch.arange(len(e_ids),
dtype=torch.int32)
return expert_placement_map
def generate_log2phy_expert_map(self, layer_id):
concatenated = torch.flatten(self.expert_map_tensor[layer_id])
rank_expert_to_global = self.generate_index_dicts(
self.expert_map_tensor[layer_id])
result_dict: Dict[int, List[int]] = {}
for idx, value in enumerate(concatenated):
key = value.item()
if key not in result_dict:
result_dict[key] = []
result_dict[key].append(idx)
log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
-1,
dtype=torch.int32)
for rank in range(self.ranks_num):
for key in result_dict:
indices_in_concat = result_dict[key]
if key in rank_expert_to_global[rank]:
log2phy_map[rank][key] = rank_expert_to_global[rank][key]
else:
chosen_index = random.choice(indices_in_concat)
log2phy_map[rank][key] = chosen_index
return log2phy_map
def get_rank_placement_map(self, layer_id, rank_id):
expert_placement_map = self.generate_expert_placement_map()
layer_expert_map = expert_placement_map[layer_id]
rank_expert_map = layer_expert_map[rank_id].to(
torch.npu.current_device())
rank_local_expert_num = torch.sum(torch.ne(rank_expert_map, -1)).item()
return rank_local_expert_num, rank_expert_map
def get_rank_log2phy_map(self, layer_id, rank_id):
layer_log2phy_map = self.generate_log2phy_expert_map(layer_id)
return layer_log2phy_map[rank_id]
def get_global_redundant_expert_num(self):
global_redundant_expert_num = (
len(self.expert_map_tensor[0][0]) * self.ranks_num -
self.global_expert_num)
return global_redundant_expert_num

View File

@@ -0,0 +1,587 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
import os
from typing import Any, Callable, Optional
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEConfig # isort: skip
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
get_all_reduce_merge_state,
get_rm_router_logits_state, is_310p)
def unified_fused_experts_eager(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
w1_scale: Optional[torch.Tensor] = None,
w1_scale_bias: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w2_scale_bias: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[Any] = None,
shared_dequant_scale: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
token_dispatcher = get_forward_context().token_dispatcher
results = token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts,
shared_gate_up=shared_gate_up,
shared_dequant_scale=shared_dequant_scale,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=with_quant)
expert_output = unified_apply_mlp(
hidden_states=results["hidden_states"],
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=results["group_list"],
dynamic_scale=results.get("dynamic_scale"),
group_list_type=results.get("group_list_type"),
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=results.get("topk_scales"),
with_quant=with_quant)
final_hidden_states = token_dispatcher.token_combine(expert_output)
return final_hidden_states
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None):
super().__init__(moe=moe)
vllm_config = get_current_vllm_config()
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_model_len = vllm_config.model_config.max_model_len
get_ascend_config()
try:
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = None
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod,
self).process_weights_after_loading(layer)
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w13_weight.data),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
topk_weights = topk_weights.to(x.dtype)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance and not self.use_aclgraph:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
return unified_fused_experts_eager(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
expert_map=expert_map,
shared_experts=shared_experts,
mc2_mask=kwargs.get(
"mc2_mask", None),
with_quant=False)
class AscendFusedMoE(FusedMoE):
# The moe_counter parameter is required during the initialization of EPLB
# to identify the current layer index within the MOE model.
moe_counter = -1
def __init__(
self,
num_experts: int, # Global number of experts
top_k: int,
hidden_size: int,
intermediate_size: int,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
ep_size: Optional[int] = None,
dp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
):
# TODO: This could not initialize FusedMoE baseclass,
# fixme and make __init__() of AscendFusedMoE more clear
super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
)
AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
if params_dtype is None:
params_dtype = torch.get_default_dtype()
vllm_config = get_current_vllm_config()
self.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=(tp_size if tp_size is not None else
get_tensor_model_parallel_world_size()),
dp_size_=(dp_size
if dp_size is not None else get_dp_group().world_size),
vllm_parallel_config=vllm_config.parallel_config)
self.top_k = top_k
self.num_experts = num_experts
self.global_num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.expert_map = None
self.activation = activation
self.log2phy = None
self.global_redundant_expert_num = 0
is_deepseek_v3_r1 = self.global_num_experts == 256
self.rm_router_logits = get_rm_router_logits_state(
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
self.all_reduce_merge = get_all_reduce_merge_state(
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
ascend_config = get_ascend_config()
expert_map_path = ascend_config.expert_map_path
if expert_map_path and os.path.exists(expert_map_path):
# moe expert load balance
expert_load_balancer = ExpertLoadBalancer(expert_map_path,
self.global_num_experts)
self.local_num_experts, self.expert_map = \
expert_load_balancer.get_rank_placement_map(
self.moe_instance_id,
get_ep_group().rank_in_group)
self.log2phy = expert_load_balancer.get_rank_log2phy_map(
self.moe_instance_id,
get_ep_group().rank_in_group)
self.global_redundant_expert_num = \
expert_load_balancer.get_global_redundant_expert_num()
else:
# Create a tensor of size num_experts filled with -1
self.local_num_experts, self.expert_map = determine_expert_map(
self.ep_size,
get_ep_group().rank_in_group, self.global_num_experts)
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
quant_config=quant_config)
self.moe_config = moe
if quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
local_num_experts = torch.sum(self.expert_map != -1) \
if self.expert_map is not None else num_experts
moe_quant_params = {
"num_experts": local_num_experts,
"hidden_size": hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.ep_group = get_ep_group()
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.token_dispatcher = None
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
setup_token_dispatchers
setup_token_dispatchers(
ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_global_redundant_experts=self.global_redundant_expert_num,
num_local_experts=self.local_num_experts)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
buffer[start:end, :].copy_(x)
for idx in range(self.dp_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx]
get_dp_group().broadcast(buffer[start:end, :], idx)
return buffer
def forward(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None,
gate=None,
replace_allreduce: bool = False,
_metadata_for_padding: Optional[MetadataForPadding] = None):
assert self.quant_method is not None
if top_k:
real_top_k = top_k
else:
real_top_k = self.top_k
num_tokens, hidden_size = hidden_states.shape
forward_context = get_forward_context()
fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
if shared_experts:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)
mc2_mask = forward_context.mc2_mask
enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
tp_size = get_tensor_model_parallel_world_size()
if enable_sp:
tp_rank = get_tensor_model_parallel_rank()
mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
replace_allreduce = True
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
if fused_moe_state in {FusedMoEState.MC2}:
padding_size = forward_context.padded_num_tokens
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
if not self.enable_shared_expert_dp:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
mc2_mask = chunk_mc2_mask[tp_rank]
if self.dp_size > 1:
if fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
max_tokens_across_dp = forward_context.max_tokens_across_dp
if num_tokens < max_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0, max_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0, max_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = get_dp_group().all_gather(router_logits, 0)
elif fused_moe_state == FusedMoEState.NaiveMulticast:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
else:
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_dp_cpu)
# Matrix multiply.
e_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=None,
mc2_mask=mc2_mask,
token_dispatcher=self.token_dispatcher,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
)
if shared_experts:
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce and not self.enable_shared_expert_dp):
if tp_size > 1:
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
dispose_tensor(e_hidden_states)
else:
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens]
elif self.dp_size > 1 and not self.enable_shared_expert_dp:
if fused_moe_state == FusedMoEState.NaiveMulticast:
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
final_hidden_states = get_dp_group().all_reduce(
e_hidden_states)
final_hidden_states = final_hidden_states[start:end, :]
dispose_tensor(e_hidden_states)
elif fused_moe_state == FusedMoEState.AllGather:
final_hidden_states = data_parallel_reduce_scatter(
e_hidden_states, dim=0)
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
else:
final_hidden_states = e_hidden_states
else:
final_hidden_states = e_hidden_states
if tp_size > 1 and not self.all_reduce_merge and fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
]:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
if shared_experts:
return final_hidden_states, shared_hidden_states
else:
return final_hidden_states
# ----------------------------------------- TBO-related --------------------------------------------
def _forward_ms_fused_moe_comp(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
real_top_k,
enable_force_load_balance: bool = False,
):
hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=real_top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
)
return hidden_states

View File

@@ -0,0 +1,85 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Optional, Tuple, Union
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
class AddRMSNormW8A8Quant(RMSNorm):
# Fuse AddRmsNorm and W8A8 quantization ops together
def __init__(
self,
hidden_size: int,
layer: torch.nn.Module,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.layer = layer
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
if residual is not None:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
self.layer.aclnn_input_scale,
self.layer.aclnn_input_offset,
epsilon=self.variance_epsilon)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x
class AscendRMSNorm(RMSNorm):
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
from vllm_ascend.utils import is_310p
if residual is not None:
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x

View File

View File

@@ -0,0 +1,283 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional
import torch
import torch_npu
def return_row_idx(hidden_states, top_k):
num_tokens = hidden_states.shape[0]
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=hidden_states.device).view(
top_k, -1).permute(1, 0).contiguous())
return row_idx
def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None,
global_num_experts: int = -1):
"""
Fused experts with select experts.
Args:
router_logits: router logits of shape (num_tokens, hidden_size).
hidden_states: Hidden states of shape (num_tokens, hidden_size).
top_k: number of top k experts.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
indices_type: dtype of indices
global_num_experts: Global number of experts.
Returns:
topk_weights: router weights of shape (num_tokens, top_k).
topk_ids: selected expert IDs of shape (num_tokens, top_k).
"""
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
renormalize=renormalize,
e_score_correction_bias=e_score_correction_bias,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts)
if topk_weights is None:
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
if row_idx is None:
row_idx = return_row_idx(hidden_states, top_k)
return topk_weights, topk_ids, row_idx
def _native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
topk_group: Optional[int],
):
topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group
num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group,
-1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
return topk_weights
def _renormalize_topk_weights(
topk_weights: torch.Tensor,
renormalize: bool,
):
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights
def _select_expert_use_group_topk(
topk_weights: torch.Tensor, topk_group: Optional[int],
renormalize: bool, top_k: int, num_expert_group: Optional[int],
e_score_correction_bias: Optional[torch.Tensor]):
assert topk_group is not None
assert num_expert_group is not None
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_weights = topk_weights
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
def _select_experts_with_fusion_ops(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
e_score_correction_bias: Optional[torch.Tensor],
topk_group: Optional[int],
num_expert_group: Optional[int],
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
global_num_experts: int = -1):
topk_weights, topk_ids, row_idx = None, None, None
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
is_deepseek_v3_r1 = global_num_experts == 256
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
row_idx = return_row_idx(hidden_states, top_k)
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids, row_idx
def _native_select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
router_logits: Router logits of shape (num_tokens, num_experts).
top_k: Number of experts to select.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
Returns:
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
Raises:
ValueError: If an unsupported scoring function is provided.
"""
if scoring_func == "softmax":
topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
if use_grouped_topk:
return _select_expert_use_group_topk(
topk_weights=topk_weights,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias)
if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids

View File

@@ -0,0 +1,199 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from typing import Optional
import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.utils import dispose_tensor, is_310p
def quant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
if dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# Dispose the original unquantized hidden states
# to save npu memory because they're no longer used.
dispose_tensor(unquantized_hidden_states)
else:
pertoken_scale = dynamic_scale
bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
if w1_scale_bias is None and is_mc2:
w1_scale = w1_scale.to(torch.float32)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
else:
if w1_scale_bias is not None:
if group_list_type == 0:
group_list = torch.cat(
[group_list[:1],
torch.diff(group_list, dim=0)])
group_list_type = 1
bias1 = [w1_scale_bias]
bias2 = [w2_scale_bias]
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
scale=[w1_scale],
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
bias=bias2,
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
return hidden_states
def unquant_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
w1 = w1.transpose(1, 2)
gate_up_out = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
if topk_scales is not None:
gate_up_out *= topk_scales
w2 = w2.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)[0]
return hidden_states
def unified_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
topk_scales: Optional[torch.Tensor] = None,
with_quant: bool = False) -> torch.Tensor:
if with_quant:
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias)
else:
return unquant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
group_list_type=group_list_type,
topk_scales=topk_scales)

309
vllm_ascend/ops/linear.py Normal file
View File

@@ -0,0 +1,309 @@
"""
Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
This file is a part of the vllm-ascend project.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional, Union
import torch
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
LinearBase,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import (
get_mlp_tensor_model_parallel_rank,
get_mlp_tensor_model_parallel_world_size, get_mlp_tp_group)
class AscendMlpColumnParallelLinear(ColumnParallelLinear):
"""Linear layer with column parallelism.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
# Divide the weight matrix along the last dimension.
if prefix.find("gate_up_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
LinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
class AscendMlpRowParallelLinear(RowParallelLinear):
"""Linear layer with row parallelism.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
if prefix.find("down_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
# Divide the weight matrix along the first dimension.
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
LinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.enable_mlp_optimze:
tp_rank = get_mlp_tensor_model_parallel_rank()
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_mlp_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0
or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
output = get_mlp_tp_group().reduce_scatter(output_parallel, 0)
# output = output[:num_tokens,:]
# dispose_tensor(output_parallel)
else:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0
or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class AscendMlpMergedColumnParallelLinear(MergedColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
Use the MLP tensor parallelism group in the MLP module,
and the original TP group in other modules.
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
if prefix.find("gate_up_proj") != -1:
self.tp_size = get_mlp_tensor_model_parallel_world_size()
self.tp_rank = get_mlp_tensor_model_parallel_rank()
self.enable_mlp_optimze = True
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.enable_mlp_optimze = False
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
AscendMlpColumnParallelLinear.__init__(self,
input_size=input_size,
output_size=sum(output_sizes),
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
def forward(
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
# self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
# Matrix multiply.
assert self.quant_method is not None
if self.enable_mlp_optimze:
input2_ = get_mlp_tp_group().all_gather(input_, 0)
output = self.quant_method.apply(self, input2_, bias)
else:
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias

View File

@@ -0,0 +1,809 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import torch
import torch_npu
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.distributed.tensor_parallel import \
gather_from_sequence_parallel_region
from vllm_ascend.ops.comm_utils import async_all_to_all
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
_Dispatchers: Dict[str, Any] = {}
def _register_token_dispatcher(dispatcher: Any):
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
def get_token_dispatcher(name: str):
return _Dispatchers.get(name)
def setup_token_dispatchers(ep_size: int, **kwargs):
existing_dispatchers = set(_Dispatchers.keys())
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
elif ep_size >= 16:
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
if "TokenDispatcherWithMC2" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
class MoETokenDispatcher(ABC):
def __init__(self, **kwargs) -> None:
"""
Initialize the MoE Token Dispatcher.
"""
self.top_k = kwargs.get("top_k", 0)
self.num_experts = kwargs.get("num_experts", 0)
@property
def ep_group(self):
"""Get expert model parallel group."""
return get_ep_group().device_group
@property
def ep_rank(self):
return get_ep_group().rank_in_group
@property
def ep_size(self):
return get_ep_group().world_size
@abstractmethod
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
raise NotImplementedError("Combine function not implemented.")
class TokenDispatcherWithMC2(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
self.ep_rank_id = get_mc2_group().rank_in_group
self.ep_world_size = get_mc2_group().world_size
self.enable_dispatch_v2 = hasattr(torch_npu,
"npu_moe_distribute_dispatch_v2")
self.need_extra_args = (
get_ascend_soc_version() == AscendSocVersion.A3)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
self.a3_need_extra_args = \
get_ascend_soc_version() == AscendSocVersion.A3
self.output = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.shared_act = None
self.topk_ids = None
self.topk_weights = None
self.shared_experts = None
self.mc2_mask = None
self.with_quant = False
def get_dispatch_mc2_kwargs(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
global_redundant_expert_num: int = 0,
):
if self.with_quant:
quant_mode = 2
if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
moe_expert_num = global_redundant_expert_num
else:
quant_mode = 0
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
stage1_kwargs = {
"scales": None,
"quant_mode": quant_mode,
"group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id,
}
if self.need_extra_args:
stage1_kwargs.update({
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": self.mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.expert_map = expert_map
self.topk_ids = topk_ids
self.topk_weights = topk_weights
self.shared_experts = shared_experts
self.mc2_mask = mc2_mask
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map,
global_redundant_expert_num)
self.output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, self.assist_info_for_combine, \
expert_token_nums, self.ep_recv_counts = self.output[0:5]
if self.with_quant:
if shared_experts is not None:
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \
shared_act_out[0], shared_act_out[1]
else:
if shared_experts is not None:
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 1
return {
"group_list_type": group_list_type,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
}
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
assert self.expert_map is not None
assert self.topk_weights is not None
assert self.topk_ids is not None
assert self.output is not None
moe_expert_num = len(self.expert_map)
# moeCombine
kwargs_mc2 = {
"expand_x": hidden_states,
"expert_ids": self.topk_ids,
"expert_scales": self.topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
}
if self.with_quant:
tp_recv_counts = torch.empty(1,
dtype=torch.int32,
device=hidden_states.device)
else:
tp_recv_counts = self.output[5]
stage3_kwargs = {
"ep_send_counts": self.ep_recv_counts,
"group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id,
}
if self.enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
self.assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": self.assist_info_for_combine,
})
if self.need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": self.mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if self.shared_experts is None:
return hidden_states
else:
if self.with_quant:
shared_hidden_states, _ = self.shared_experts.down_proj(
(self.shared_act, self.swiglu_out_scale))
else:
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
return hidden_states, shared_hidden_states
class TokenDispatcherWithAllGather(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens")
self.num_experts_local = kwargs.get("num_local_experts", 0)
self.sorted_weights = None
self.expanded_row_idx = None
self.sorted_token_indices = None
self.original_shape = None
self.mask = None
self.expert_map = None
self.topk_weights = None
self.topk_ids = None
self.with_quant = False
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
device = hidden_states.device
self.expert_map = expert_map
self.topk_weights = topk_weights
self.topk_ids = topk_ids
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
if expert_map is not None:
# Generate token indices and flatten
token_indices = (torch.arange(
num_tokens, device=device,
dtype=torch.int64).unsqueeze(1).expand(-1,
self.top_k).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = expert_map[experts_flat]
# Filter valid token-expert pairs
self.mask = local_experts_flat != -1
filtered_weights = torch.where(
self.mask, weights_flat,
torch.zeros_like(weights_flat)).to(dtype)
filtered_experts = torch.where(
self.mask, local_experts_flat,
torch.full_like(local_experts_flat,
self.num_experts_local)).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
self.sorted_token_indices = token_indices[sort_indices]
self.sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(self.num_experts_local + 1,
device=device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64),
ones)
token_counts = token_counts[:self.num_experts_local]
# Rearrange hidden_states
sorted_hidden_states = hidden_states[self.sorted_token_indices]
if self.with_quant:
group_list_type = 1
expert_tokens = token_counts
else:
expert_tokens = torch.cumsum(token_counts,
dim=0,
dtype=torch.int64)
group_list_type = 0
else:
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=active_num)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, self.num_experts_local)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
assert self.original_shape is not None
dtype = hidden_states.dtype
device = hidden_states.device
if self.expert_map is not None:
assert self.mask is not None
assert self.sorted_token_indices is not None
assert self.sorted_weights is not None
weighted_down_out = hidden_states * \
self.sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros(*self.original_shape,
device=hidden_states.device,
dtype=hidden_states.dtype)
# TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
# This created multiple NaN and index_add_ will mix them up which harms accuracy
# remove this mask and filter after it being fixed
num_valid_tokens = self.mask.sum()
valid_token_mask = torch.arange(
0, self.sorted_token_indices.shape[0],
device=device).unsqueeze(1) < num_valid_tokens
valid_output = torch.where(
valid_token_mask, weighted_down_out,
torch.zeros_like(weighted_down_out)).to(dtype)
final_hidden_states.index_add_(0, self.sorted_token_indices,
valid_output)
else:
if self.with_quant:
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=self.topk_weights,
expanded_src_to_dst_row=self.expanded_row_idx,
export_for_source_row=self.topk_ids,
)
if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(
self.original_shape)
else:
scales = torch.ones_like(
self.topk_weights
) if self.apply_router_weight_on_input else self.topk_weights
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=scales,
expanded_src_to_dst_row=self.expanded_row_idx,
export_for_source_row=self.topk_ids,
)
return final_hidden_states
# mypy: disable-error-code="override"
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.local_ep = 1
self.local_num_experts = self.num_experts // self.local_ep
self.local_num_group = self.top_k // self.local_ep
self.bsz = None
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
self.bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, self.sorted_topk_ids // self.local_num_group)
experts_id = torch.arange(0,
self.local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (
flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, self.sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": group_list,
"topk_scales": topk_scales,
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
torch.int32)
unsorted_hidden_states = hidden_states.index_select(
0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
self.bsz, self.top_k // self.local_ep, -1).sum(1)
return final_hidden_states
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
"""
The implementation of the AlltoAll-based token dispatcher, which handles token
dispatching on the sequence level instead of token level. The core of this implementation
lies in each device dispatching on the entire sequence, with the hidden state being partitioned.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.with_quant = False
self.num_local_experts = kwargs.get("num_local_experts", 0)
self.num_global_redundant_experts = kwargs.get(
"num_global_redundant_experts", 0)
self.num_experts = self.num_experts + self.num_global_redundant_experts
self.hidden_shape = None
self.topk_weights = None
self.input_splits = None
self.output_splits = None
self.hidden_shape_before_permute = None
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
# to each local expert by all ranks.
self.num_global_tokens_per_local_expert = None
# cached intermediate tensors.
self.tokens_per_expert = None
self.global_input_tokens_local_experts_indices = None
assert self.num_local_experts > 0, "Expected at least one expert"
if self.num_local_experts > 1:
self.expert_ids_per_ep_rank = torch.tensor(
[i % self.num_local_experts for i in range(self.num_experts)],
dtype=torch.int32,
device=torch.npu.current_device(),
)
local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
self.local_expert_indices = [
local_expert_indices_offset + i
for i in range(self.num_local_experts)
]
assert (len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
for i in range(len(self.local_expert_indices) - 1):
assert (self.local_expert_indices[i] ==
self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous"
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"
assert topk_ids.dim() == 2, "Expected 2D tensor for routing map"
if log2phy is not None:
topk_ids = log2phy[topk_ids]
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = self._dispatch_preprocess(
hidden_states, topk_ids)
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
dynamic_scale_after_all2all = None
if self.with_quant:
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
permutated_local_input_tokens)
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
dynamic_scale,
self.output_splits,
self.input_splits,
self.ep_group,
)
permute2_ep_all_to_all_handle.wait()
dynamic_scale.untyped_storage().resize_(0)
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
permutated_local_input_tokens,
self.output_splits,
self.input_splits,
self.ep_group,
)
permute1_ep_all_to_all_handle.wait()
permutated_local_input_tokens.untyped_storage().resize_(0)
global_input_tokens, dynamic_scale = self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all)
return {
"hidden_states": global_input_tokens,
"group_list": tokens_per_expert,
"dynamic_scale": dynamic_scale,
"group_list_type": 1
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
hidden_states = self._combine_preprocess(hidden_states)
# Perform expert parallel AlltoAll communication
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
_, permutated_local_input_tokens, handle = async_all_to_all(
hidden_states, self.input_splits, self.output_splits,
self.ep_group)
handle.wait()
hidden_states.untyped_storage().resize_(0)
output = self._combine_postprocess(permutated_local_input_tokens)
self.input_splits = None
self.output_splits = None
self.num_global_tokens_per_local_expert = None
return output
def _dispatch_preprocess(self, hidden_states, topk_ids):
assert self.hidden_shape is not None
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
tokens_per_expert = self._preprocess(topk_ids)
self.hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
tokens=hidden_states,
indices=topk_ids,
num_out_tokens=self.num_out_tokens,
)
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
def _preprocess(self, topk_ids: torch.Tensor) -> torch.Tensor:
num_local_tokens_per_expert = torch.histc(topk_ids,
bins=self.num_experts,
min=0,
max=self.num_experts)
ep_size = self.ep_size
# Dropless
self.num_out_tokens = topk_ids.numel()
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
self.input_splits = (num_local_tokens_per_expert.reshape(
ep_size,
self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
non_blocking=True).numpy())
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
num_local_tokens_per_expert,
group=self.ep_group).reshape(ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
0]:self.local_expert_indices[-1] + 1]
if self.num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before sum.")
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
axis=0)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
if self.num_local_experts > 1:
if self.num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before operations."
)
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank,
self.num_global_tokens_per_local_expert.ravel())
return num_tokens_per_local_expert
def _dispatch_postprocess(self, global_input_tokens, dynamic_scale=None):
# Early return if no local experts or no tokens
if self.num_local_experts <= 1:
return global_input_tokens, None
# Handle quantized case
if self.with_quant:
assert self.global_input_tokens_local_experts_indices is not None, \
"global_input_tokens_local_experts_indices must be initialized before calling _dispatch_postprocess"
expert_idx_2d = self.global_input_tokens_local_experts_indices.unsqueeze(
-1)
active_num = self.global_input_tokens_local_experts_indices.numel()
# Handle case with no active tokens
if active_num <= 0:
self.reversed_global_input_permutation_mapping = self.global_input_tokens_local_experts_indices
return global_input_tokens, dynamic_scale
# Process with active tokens
global_input_tokens, self.reversed_global_input_permutation_mapping, _, expanded_scale = torch_npu.npu_moe_init_routing_v2(
global_input_tokens,
expert_idx_2d,
scale=dynamic_scale,
active_num=active_num,
expert_capacity=0,
expert_num=self.num_local_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[0, self.num_local_experts],
quant_mode=-1,
row_idx_type=0)
return global_input_tokens, expanded_scale
# Handle non-quantized case
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens,
self.global_input_tokens_local_experts_indices)
return global_input_tokens, None
def _combine_preprocess(self, hidden_states):
# Unpermutation 2: expert output to AlltoAll input
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
hidden_states = torch_npu.npu_moe_token_unpermute(
hidden_states, self.reversed_global_input_permutation_mapping)
return hidden_states
def _combine_postprocess(self, permutated_local_input_tokens):
# Unpermutation 1: AlltoAll output to output
output = torch_npu.npu_moe_token_unpermute(
permuted_tokens=permutated_local_input_tokens,
sorted_indices=self.reversed_local_input_permutation_mapping.to(
torch.int32),
probs=self.topk_weights,
restore_shape=self.hidden_shape_before_permute)
# Reshape the output tensor
output = output.view(self.hidden_shape)
return output

View File

@@ -0,0 +1,339 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
from typing import Optional, Tuple
import torch
import torch_npu
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import enable_custom_op, is_310p
def _custom_rotary_embedding_enabled(query, neox_style, head_size):
return query.dtype == torch.float16 and neox_style and head_size % 32 == 0 and enable_custom_op(
)
def _rope_forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query_shape, key_shape = query.shape, key.shape
if self.cos_sin_cache.device != query.device:
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
neox_style = self.is_neox_style
if is_neox_style_override is not None:
neox_style = is_neox_style_override
# adopt custom kernel path for rotary_embedding
if _custom_rotary_embedding_enabled(query, neox_style,
self.head_size) and not is_310p():
query, key = torch.ops._C.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
if offsets is not None:
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if self.rotary_dim < self.head_size:
num_tokens = query.shape[0]
query = query.view(num_tokens, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
q_rot = query[..., :self.rotary_dim]
q_pass = query[..., self.rotary_dim:]
k_rot = key[..., :self.rotary_dim]
k_pass = key[..., self.rotary_dim:]
q_rot = q_rot.contiguous().view(num_tokens, -1)
k_rot = k_rot.contiguous().view(num_tokens, -1)
torch_npu._npu_rotary_embedding(
positions,
q_rot,
k_rot,
self.head_size,
self.cos_sin_cache,
neox_style,
)
q_rot = q_rot.view(num_tokens, -1, self.rotary_dim)
k_rot = k_rot.view(num_tokens, -1, self.rotary_dim)
q = torch.cat((q_rot, q_pass), dim=-1).reshape(query_shape)
k = torch.cat((k_rot, k_pass), dim=-1).reshape(key_shape)
return q, k
# TODO: Remove the contiguous in the future.
query = query.contiguous().view(query.shape[0], -1)
key = key.contiguous().view(key.shape[0], -1)
torch_npu._npu_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
neox_style,
)
return query.view(query_shape), key.view(key_shape)
class AscendRotaryEmbedding(RotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def forward_oot(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
is_neox_style_override: Optional[bool] = None,
):
return _rope_forward_oot(
self,
positions,
query,
key,
offsets,
is_neox_style_override,
)
class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
# Note: we adopt the native huggingface deepseek rope initialization code from
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py for
# its more ascend compute friendly
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
self._yarn_get_mscale(self.scaling_factor, float(mscale)) /
self._yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) *
attn_factor)
super(DeepseekScalingRotaryEmbedding,
self).__init__(head_size, rotary_dim, max_position_embeddings,
base, is_neox_style, dtype)
self.max_seq_len = max_position_embeddings
self._set_cos_sin_cache(seq_len=max_position_embeddings,
device=NPUPlatform.device_type,
dtype=dtype)
def _yarn_get_mscale(self, scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def _rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def _yarn_linear_ramp_mask(self, min_value, max_value, dim):
# Note: The if conditional branch is not used here
# to solve MTP compilation error.
max_value += (min_value == max_value).float() * 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) -
min_value) / (max_value - min_value)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(self,
num_rotations,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
return (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
# Find dim range bounds based on rotations
def _yarn_find_correction_range(self,
low_rot,
high_rot,
dim,
base=10000,
max_position_embeddings=2048):
# Note: use torch instead of math to solve MTP compilation error.
low = torch.floor(
self._yarn_find_correction_dim(low_rot, dim, base,
max_position_embeddings))
high = torch.ceil(
self._yarn_find_correction_dim(high_rot, dim, base,
max_position_embeddings))
# Note: use torch instead of max/min to solve MTP compilation error.
return torch.clamp(low, min=0), torch.clamp(high, max=dim - 1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def _apply_rotary_pos_emb(self,
q,
k,
cos,
sin,
position_ids,
unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids]
sin = sin[position_ids]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
if len(q.shape) == 3:
q = q[:, :, None, :]
if len(k.shape) == 2:
k = k[:, None, None, :]
elif len(k.shape) == 3:
k = k[:, :, None, :]
b, h_q, s, d = q.shape
q = q.view(b, h_q, s, d // 2, 2).transpose(4, 3).reshape(b, h_q, s, d)
b, h_k, s, d = k.shape
k = k.view(b, h_k, s, d // 2, 2).transpose(4, 3).reshape(b, h_k, s, d)
q_embed = (q * cos) + (self._rotate_half(q) * sin)
k_embed = (k * cos) + (self._rotate_half(k) * sin)
q_embed = q_embed.view(b, h_q, d)
k_embed = k_embed.view(b, h_k, d)
return q_embed, k_embed
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
dim = self.rotary_dim
freq_extra = 1.0 / (self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
freq_inter = 1.0 / (self.scaling_factor * self.base**(
torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
low, high = self._yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
dim,
self.base,
self.max_position_embeddings,
)
inv_freq_mask = 1.0 - self._yarn_linear_ramp_mask(
low, high, dim // 2).to(device=device, dtype=torch.float32)
inv_freq = freq_inter * (1 -
inv_freq_mask) + freq_extra * inv_freq_mask
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(seq_len * self.scaling_factor,
device=device,
dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
cos_cached = torch.cat([freqs, freqs], dim=-1).cos() * self.mscale
sin_cached = torch.cat([freqs, freqs], dim=-1).sin() * self.mscale
cos_cached = cos_cached.to(dtype)
sin_cached = sin_cached.to(dtype)
cache = torch.cat(
[freqs.cos() * self.mscale,
freqs.sin() * self.mscale], dim=-1).to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
def forward(self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
max_seq_len: Optional[int] = None):
if max_seq_len is not None and max_seq_len > self.max_seq_len:
self._set_cos_sin_cache(max_seq_len, query.device, query.dtype)
if len(key.shape) == 2:
key = key[:, None, :]
# Note: we implement the non neox_style method with shuffle the last dim and neox style
# calculation method which is also more compute friendly to the ascend machine
# https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py
neox_style = True
if self.is_neox_style is False:
b, h_q, d = query.shape
query = query.view(b, h_q, d // 2,
2).transpose(3, 2).reshape(b, h_q, d)
b, h_k, d = key.shape
key = key.view(b, h_k, d // 2, 2).transpose(3,
2).reshape(b, h_k, d)
q_pe, k_pe = _rope_forward_oot(self, positions, query, key, offsets,
neox_style)
return q_pe, k_pe

View File

@@ -0,0 +1,120 @@
import torch
from torch.nn import functional as F
from vllm.distributed import (get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm_ascend.platform import NPUPlatform
class MetadataForPadding:
def __init__(self,
padding_flag=False,
lengths_sum_padding=0,
lengths_sum_unpadding=0,
pad_size=0,
not_dummy_and_is_prefill=False):
self.padding_flag = padding_flag
self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
self.lengths_sum_padding = lengths_sum_padding
self.lengths_sum_unpadding = lengths_sum_unpadding
self.pad_size = pad_size
self.tp_size = get_tp_group().world_size
self.tp_rank_in_group = get_tp_group().rank_in_group
assert self.lengths_sum_padding % self.tp_size == 0
self.slice_size = self.lengths_sum_padding // self.tp_size
self.mc2_mask = torch.zeros(
self.lengths_sum_padding,
dtype=torch.bool,
device=NPUPlatform.device_type,
)
self.mc2_mask[:lengths_sum_unpadding] = True
def padding_aligned_reduce_scatter(self,
data: torch.Tensor) -> torch.Tensor:
if self.padding_flag:
pad_size = self.pad_size
padded_data = F.pad(data, (0, 0, 0, pad_size))
else:
padded_data = data
padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
padded_data, 0)
return padded_data_reduce_scatter
def allgather_unpadding_aligned(self,
padded_data: torch.Tensor) -> torch.Tensor:
padded_data_allgather = tensor_model_parallel_all_gather(
padded_data, 0)
if self.padding_flag:
lengths_sum_unpadding = self.lengths_sum_unpadding
unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
else:
unpadding_data = padded_data_allgather
return unpadding_data
def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
padded_data = F.pad(data, (0, 0, 0, self.pad_size))
start = self.tp_rank_in_group * self.slice_size
end = start + self.slice_size
slice_data = padded_data[start:end]
return slice_data
def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
if self.padding_flag:
pad_size = self.pad_size
padded_data = F.pad(data, (0, 0, 0, pad_size))
else:
padded_data = data
# padded_data = data
padded_data = torch.tensor_split(padded_data, self.tp_size, dim=0)
padded_data_reduce_scatter = padded_data[self.tp_rank_in_group]
return padded_data_reduce_scatter
def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
if not enable_sequence_parallelism:
return MetadataForPadding(padding_flag=False,
not_dummy_and_is_prefill=False)
is_perifll = 0
attn_metadata = get_forward_context().attn_metadata
tp_size = get_tensor_model_parallel_world_size()
if attn_metadata is not None:
if hasattr(attn_metadata,
'is_only_prefill') and attn_metadata.is_only_prefill:
is_perifll = 1
if hasattr(attn_metadata,
'num_prefills') and attn_metadata.num_prefills > 0:
is_perifll = 1
if is_perifll:
lengths_sum_unpadding = input_ids.shape[0]
lengths_sum_padding = (
(lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
if lengths_sum_unpadding == lengths_sum_padding:
padding_flag = False
else:
padding_flag = True
pad_size = lengths_sum_padding - lengths_sum_unpadding
_metadata_for_padding = MetadataForPadding(
lengths_sum_unpadding=lengths_sum_unpadding,
lengths_sum_padding=lengths_sum_padding,
padding_flag=padding_flag,
pad_size=pad_size,
not_dummy_and_is_prefill=True)
return _metadata_for_padding
return MetadataForPadding(padding_flag=False,
not_dummy_and_is_prefill=False)

View File

@@ -0,0 +1,254 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Tuple
import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import divide, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
VocabParallelEmbedding, pad_vocab_size)
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.distributed.parallel_state import get_lmhead_tp_group
from vllm_ascend.utils import lmhead_tp_enable
class AscendVocabParallelEmbedding(VocabParallelEmbedding):
"""
Register VocabParallelEmbedding as a custom op for Ascend.
AscendVocabParallelEmbedding support different communication parallel groups
Added the feature of lmheadTP in pure dp scenario
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
nn.Module.__init__(self)
if lmhead_tp_enable() and prefix.find("lm_head") != -1:
self.comm_group = get_lmhead_tp_group()
else:
self.comm_group = get_tp_group()
self.tp_size = self.comm_group.world_size
self.tp_rank = self.comm_group.rank_in_group
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size,
self.tp_rank, self.tp_size)
self.embedding_dim = embedding_dim
quant_method = None
if quant_config is not None:
quant_method = quant_config.get_quant_method(self, prefix=prefix)
if quant_method is None:
quant_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocaburaly dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.quant_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def _get_masked_input_and_mask(
self, input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (
input_ < org_vocab_end_index)
# Adapt: avoid create added_vocab_mask when added_vocab_start_index == added_vocab_end_index.
if added_vocab_start_index == added_vocab_end_index:
valid_offset = (org_vocab_start_index * org_vocab_mask)
vocab_mask = org_vocab_mask
else:
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index -
org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
# Adapt end.
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = self._get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = self.quant_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
class AscendParallelLMHead(ParallelLMHead):
"""
Register ParallelLMHead as a custom op for Ascend."""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
AscendVocabParallelEmbedding.__init__(self, num_embeddings,
embedding_dim, params_dtype,
org_num_embeddings, padding_size,
quant_config, prefix)
self.quant_config = quant_config
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
class AscendLogitsProcessor(LogitsProcessor):
"""
Register LogitsProcessor as a custom op for Ascend.
Added the feature of lmheadTP in pure dp scenario
"""
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: AscendParallelLMHead,
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
if lmhead_tp_enable():
return self._get_logits_lmheadtp(hidden_states, lm_head,
embedding_bias)
else:
return self._get_logits_normal(hidden_states, lm_head,
embedding_bias)
def _get_logits_lmheadtp(
self,
hidden_states: torch.Tensor,
lm_head: AscendParallelLMHead,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Gather hidden states from all devices in tensor parallel group
gathered_hidden_states = get_lmhead_tp_group().all_gather(
hidden_states, dim=0)
local_logits = lm_head.quant_method.apply(lm_head,
gathered_hidden_states,
bias=embedding_bias)
# Gather logits for tensor parallel
logits = get_lmhead_tp_group().all_to_all(local_logits)
# Remove paddings in vocab (if any)
if logits is not None:
logits = logits[..., :self.org_vocab_size]
return logits
def _get_logits_normal(
self,
hidden_states: torch.Tensor,
lm_head: AscendParallelLMHead,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
local_logits = lm_head.quant_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
# Gather logits for tensor parallel
logits = self._gather_logits(local_logits)
# Remove paddings in vocab (if any)
if logits is not None:
logits = logits[..., :self.org_vocab_size]
return logits