[1/N][Draft][Refactor]torchair pangu_moe modeling refactor (#2437)
### What this PR does / why we need it?
1. Similar to #2384, this PR adds a torchair-specific model
implementation for pangu.
2. Fixes a bug introduced with the `routed_scaling_factor` argument in #2675.
3. Removes the eager test case for pangu, since a torchair test case
already exists.
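
For context on item 2: a routed scaling factor is a constant multiplier applied to the router weights of the selected experts. A minimal sketch of where it enters softmax top-k routing (illustrative names, not this repo's implementation):

```python
import torch

def topk_routing(router_logits: torch.Tensor,
                 top_k: int,
                 routed_scaling_factor: float = 1.0,
                 renormalize: bool = True):
    # Softmax scores over experts, then keep the top-k per token.
    scores = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = scores.topk(top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # The factor must be applied (or forwarded) consistently on every code
    # path; a mismatch between paths is the kind of bug item 2 fixes.
    return topk_weights * routed_scaling_factor, topk_ids
```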
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: 6997a25ac6
---------
Signed-off-by: zengyanjia <z00883269@china.huawei.com>
Signed-off-by: Angazenn <supperccell@163.com>
Co-authored-by: zengyanjia <z00883269@china.huawei.com>
```diff
@@ -22,8 +22,6 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 import os
-from typing import Dict
-
 import pytest
 
 from tests.e2e.conftest import VllmRunner
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
```
```diff
@@ -155,7 +153,6 @@ def _pangu_torchair_test_fixture(
         print(f"Generated text: {vllm_output[i][1]!r}")
 
 
-@pytest.mark.skip("pangu doesn't work, fix me")
 def test_e2e_pangu_with_torchair():
     additional_config = {
         "torchair_graph_config": {
```
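
For reference, a minimal sketch of how such a torchair e2e case drives `VllmRunner`; the config body is truncated in the hunk above, so every field except the `torchair_graph_config` key itself is an assumption, and the model path is a placeholder:

```python
from tests.e2e.conftest import VllmRunner

def run_pangu_torchair_smoke():
    additional_config = {
        "torchair_graph_config": {
            "enabled": True,  # assumption: the truncated test enables graph mode
        },
    }
    # "<pangu-model-path>" is a placeholder, not the path used by the test.
    with VllmRunner("<pangu-model-path>",
                    additional_config=additional_config,
                    enforce_eager=False) as runner:
        return runner.generate_greedy(["Hello, my name is"], max_tokens=16)
```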
```diff
@@ -65,7 +65,7 @@ class TestTorchairUtils(TestBase):
         mock_model_registry.return_value = mock_registry
         utils.register_torchair_model()
 
-        self.assertEqual(mock_model_registry.register_model.call_count, 5)
+        self.assertEqual(mock_model_registry.register_model.call_count, 6)
         call_args_list = mock_model_registry.register_model.call_args_list
 
         expected_registrations = [
```
```diff
@@ -81,7 +81,11 @@ class TestTorchairUtils(TestBase):
             ("Qwen2ForCausalLM",
              "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"),
             ("Qwen3MoeForCausalLM",
-             "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
+             "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM"
+             ),
+            ("PanguProMoEForCausalLM",
+             "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
+             )
         ]
 
         for i, (expected_name,
```
```diff
@@ -57,7 +57,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
 
 _ROUTER_SCALE = None
```
```diff
@@ -612,9 +611,6 @@ class PanguProMoEAttention(nn.Module):
             prefix=f"{prefix}.attn",
         )
 
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
     def forward(
         self,
         positions: torch.Tensor,
```
```diff
@@ -625,18 +621,7 @@
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        if self.torchair_graph_enabled:
-            forward_kwargs = {'trace_flag': False}
-            output_shape = q.shape
-            attn_output = torch.empty(output_shape,
-                                      dtype=q.dtype,
-                                      device=q.device)
-            forward_kwargs['output'] = attn_output
-            attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
-                                                 attn_metadata,
-                                                 **forward_kwargs)
-        else:
-            attn_output = self.attn(q, k, v)
+        attn_output = self.attn(q, k, v)
 
         output, _ = self.o_proj(attn_output)
         return output
```
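
The deleted branch is not lost; it is the torchair graph path that moves into the new torchair-specific model. Condensed from the removed lines above into a method-shaped sketch (presumably the output is preallocated so the captured graph works against a stable buffer):

```python
import torch

def torchair_attention(self, q, k, v, kv_cache, attn_metadata):
    # Condensed from the removed lines: in graph mode, hand the backend impl
    # a preallocated output buffer and disable tracing inside the op.
    if self.torchair_graph_enabled:
        attn_output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
                                             attn_metadata,
                                             trace_flag=False,
                                             output=attn_output)
    else:
        attn_output = self.attn(q, k, v)
    return attn_output
```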
```diff
@@ -170,15 +170,6 @@ def fused_experts_moge(
     local_num_experts = global_num_experts // ep_size
     local_num_group = top_k // ep_size
 
-    if apply_router_weight_on_input:
-        assert (topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-        _, topk = topk_weights.shape
-        assert (
-            topk == 1
-        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
-
     bsz, _ = hidden_states.shape
     flatten_topk_ids = topk_ids.view(-1)
     sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
```
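
The removed guard presumably lives on in the torchair copy of `fused_experts_moge`. Condensed, it folded the single router weight into the activations before dispatch, which is only valid for top_k == 1:

```python
import torch

def prescale_on_router_weight(hidden_states: torch.Tensor,
                              topk_weights: torch.Tensor) -> torch.Tensor:
    # Condensed from the removed lines: pre-scale the inputs so the expert
    # outputs need no weighting afterwards.
    assert topk_weights.dim() == 2, \
        "`topk_weights` should be in shape (num_tokens, topk)"
    assert topk_weights.shape[1] == 1, \
        "Only support topk=1 when `apply_router_weight_on_input` is True"
    return hidden_states * topk_weights.to(hidden_states.dtype)
```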
```diff
@@ -407,6 +398,7 @@ class AscendFusedMoE(FusedMoE):
         prefix="",
         custom_routing_function=None,
         scoring_func="softmax",
+        routed_scaling_fator: float = 1.0,
         e_score_correction_bias=None,
         apply_router_weight_on_input=False,
         activation="silu",
```
```diff
@@ -414,31 +406,59 @@
         num_redundant_experts=0,
         has_bias=False,
     ):
-        super().__init__(
-            num_experts,
-            top_k,
-            hidden_size,
-            intermediate_size,
-            params_dtype,
-            reduce_results,
-            renormalize,
-            use_grouped_topk,
-            num_expert_group,
-            topk_group,
-            quant_config,
-            tp_size,
-            ep_size,
-            dp_size,
-            prefix,
-            custom_routing_function,
-            scoring_func,
-            e_score_correction_bias,
-            apply_router_weight_on_input,
-            activation,
-            enable_eplb,
-            num_redundant_experts,
-            has_bias,
-        )
+        if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
+            super().__init__(
+                num_experts,
+                top_k,
+                hidden_size,
+                intermediate_size,
+                params_dtype,
+                reduce_results,
+                renormalize,
+                use_grouped_topk,
+                num_expert_group,
+                topk_group,
+                quant_config,
+                tp_size,
+                ep_size,
+                dp_size,
+                prefix,
+                custom_routing_function,
+                scoring_func,
+                e_score_correction_bias,
+                apply_router_weight_on_input,
+                activation,
+                enable_eplb,
+                num_redundant_experts,
+                has_bias,
+            )
+        else:
+            super().__init__(
+                num_experts,
+                top_k,
+                hidden_size,
+                intermediate_size,
+                params_dtype,
+                reduce_results,
+                renormalize,
+                use_grouped_topk,
+                num_expert_group,
+                topk_group,
+                quant_config,
+                tp_size,
+                ep_size,
+                dp_size,
+                prefix,
+                custom_routing_function,
+                scoring_func,
+                routed_scaling_fator,
+                e_score_correction_bias,
+                apply_router_weight_on_input,
+                activation,
+                enable_eplb,
+                num_redundant_experts,
+                has_bias,
+            )
 
         setup_token_dispatchers(self.moe_config.ep_size,
                                 top_k=self.top_k,
```
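
The gate exists because, as the diff encodes, `FusedMoE.__init__` in vLLM v0.10.1.x takes no routed scaling argument, while newer vLLM expects it positionally after `scoring_func`; passing it to the old signature would shift every later argument. A small sketch of the compatibility check (`vllm_version_is` is the exact-match version helper from `vllm_ascend.utils`):

```python
from vllm_ascend.utils import vllm_version_is

def fused_moe_takes_scaling_factor() -> bool:
    # Older releases must not receive the extra positional argument.
    return not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"))
```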
vllm_ascend/torchair/models/torchair_pangu_moe.py (new file, 1119 lines; diff suppressed because it is too large)
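
The suppressed file is the torchair-specific copy of the pangu modeling code: per the hunks above and below, it carries the `PanguProMoEForCausalLM` class registered in `register_torchair_model`, including the graph-mode attention path and the MoE input pre-scaling that this PR removes from the eager model.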
```diff
@@ -173,6 +173,11 @@ def register_torchair_model():
         "Qwen3MoeForCausalLM",
         "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
 
+    ModelRegistry.register_model(
+        "PanguProMoEForCausalLM",
+        "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
+    )
+
 
 def torchair_quant_method_register():
     from vllm_ascend.quantization.quantizer import \
```
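
`register_torchair_model` hands `ModelRegistry` a lazy "module:Class" string, so `torchair_pangu_moe.py` is only imported when the `PanguProMoEForCausalLM` architecture is actually requested. A hedged sketch of the same pattern with vLLM's public plugin API:

```python
from vllm import ModelRegistry

# Register the architecture name against a lazily imported class path.
ModelRegistry.register_model(
    "PanguProMoEForCausalLM",
    "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM")

# The module behind the string is imported at lookup time, e.g. while
# loading a checkpoint (assumption: current vLLM keeps this method name).
model_cls, arch = ModelRegistry.resolve_model_cls(["PanguProMoEForCausalLM"])
```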