[1/N][Draft][Refactor] torchair pangu_moe modeling refactor (#2437)
### What this PR does / why we need it?
1. Similar to #2384, this PR adds a torchair-specific modeling file for
PanGu (see the usage sketch after this list).
2. Fixes a bug introduced by `routed_scaling_factor` in #2675.
3. Removes the eager-mode test case for PanGu, since a torchair test case
already exists.
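As context for how the new modeling path is exercised, here is a minimal usage sketch mirroring the `torchair_graph_config` that the e2e test passes through `additional_config`. The model path and sampling settings are placeholders, not values from this PR:

```python
# Hedged sketch: enable torchair graph mode on vllm-ascend via
# additional_config, the same knob the e2e test below uses.
from vllm import LLM, SamplingParams

llm = LLM(
    model="path/to/pangu-pro-moe",  # placeholder checkpoint path
    tensor_parallel_size=4,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
        },
    },
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=16)))
```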
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main: 6997a25ac6
---------
Signed-off-by: zengyanjia <z00883269@china.huawei.com>
Signed-off-by: Angazenn <supperccell@163.com>
Co-authored-by: zengyanjia <z00883269@china.huawei.com>
@@ -22,8 +22,6 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 import os
 from typing import Dict
 
-import pytest
-
 from tests.e2e.conftest import VllmRunner
 
 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
@@ -155,7 +153,6 @@ def _pangu_torchair_test_fixture(
         print(f"Generated text: {vllm_output[i][1]!r}")
 
 
-@pytest.mark.skip("pangu doesn't work, fix me")
 def test_e2e_pangu_with_torchair():
     additional_config = {
         "torchair_graph_config": {
@@ -65,7 +65,7 @@ class TestTorchairUtils(TestBase):
         mock_model_registry.return_value = mock_registry
         utils.register_torchair_model()
 
-        self.assertEqual(mock_model_registry.register_model.call_count, 5)
+        self.assertEqual(mock_model_registry.register_model.call_count, 6)
         call_args_list = mock_model_registry.register_model.call_args_list
 
         expected_registrations = [
@@ -81,7 +81,11 @@ class TestTorchairUtils(TestBase):
             ("Qwen2ForCausalLM",
              "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"),
             ("Qwen3MoeForCausalLM",
-             "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
+             "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM"
+             ),
+            ("PanguProMoEForCausalLM",
+             "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
+             )
         ]
 
         for i, (expected_name,
@@ -57,7 +57,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
 
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
 
 _ROUTER_SCALE = None
@@ -612,9 +611,6 @@ class PanguProMoEAttention(nn.Module):
             prefix=f"{prefix}.attn",
         )
 
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-
    def forward(
        self,
        positions: torch.Tensor,
@@ -625,18 +621,7 @@ class PanguProMoEAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        if self.torchair_graph_enabled:
-            forward_kwargs = {'trace_flag': False}
-            output_shape = q.shape
-            attn_output = torch.empty(output_shape,
-                                      dtype=q.dtype,
-                                      device=q.device)
-            forward_kwargs['output'] = attn_output
-            attn_output = self.attn.impl.forward(self.attn, q, k, v, kv_cache,
-                                                 attn_metadata,
-                                                 **forward_kwargs)
-        else:
-            attn_output = self.attn(q, k, v)
+        attn_output = self.attn(q, k, v)
 
         output, _ = self.o_proj(attn_output)
         return output
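The removed branch is the torchair-specific path this series moves into the new torchair model: it pre-allocates the attention output and calls the backend implementation directly with `trace_flag=False`, so graph capture sees a fixed output buffer. A standalone sketch recovered from the removed lines (the helper name is illustrative):

```python
import torch

# Sketch of the torchair-style attention call removed above: allocate the
# output tensor up front and invoke the attention impl directly.
def torchair_attention(attn, q, k, v, kv_cache, attn_metadata):
    attn_output = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    return attn.impl.forward(attn, q, k, v, kv_cache, attn_metadata,
                             trace_flag=False, output=attn_output)
```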
@@ -170,15 +170,6 @@ def fused_experts_moge(
     local_num_experts = global_num_experts // ep_size
     local_num_group = top_k // ep_size
 
-    if apply_router_weight_on_input:
-        assert (topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-        _, topk = topk_weights.shape
-        assert (
-            topk == 1
-        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
-
     bsz, _ = hidden_states.shape
     flatten_topk_ids = topk_ids.view(-1)
     sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
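For reference, the removed block implemented `apply_router_weight_on_input`: with `topk == 1` the router weight can be folded into the hidden states once before expert dispatch instead of being applied to each expert's output. A standalone sketch of the technique (the function name is illustrative):

```python
import torch

def fold_router_weight_into_input(hidden_states: torch.Tensor,
                                  topk_weights: torch.Tensor) -> torch.Tensor:
    # topk_weights has shape (num_tokens, topk); only topk == 1 lets the
    # weight be applied once on the input rather than per expert output.
    assert topk_weights.dim() == 2, "`topk_weights` should be (num_tokens, topk)"
    assert topk_weights.shape[1] == 1, "only topk == 1 is supported"
    return hidden_states * topk_weights.to(hidden_states.dtype)
```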
@@ -407,6 +398,7 @@ class AscendFusedMoE(FusedMoE):
         prefix="",
         custom_routing_function=None,
         scoring_func="softmax",
+        routed_scaling_fator: float = 1.0,
         e_score_correction_bias=None,
         apply_router_weight_on_input=False,
         activation="silu",
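(Note the parameter is spelled `routed_scaling_fator` in this diff.) In DeepSeek-style MoE layers, the routed scaling factor multiplies the combined routed-expert output before it rejoins the residual stream, which is the quantity the #2675 fix threads through to `FusedMoE`. A minimal sketch of the effect; shapes and names are illustrative, not this class's API:

```python
import torch

def combine_routed_experts(expert_out: torch.Tensor,
                           topk_weights: torch.Tensor,
                           routed_scaling_factor: float = 1.0) -> torch.Tensor:
    # expert_out: (num_tokens, topk, hidden); topk_weights: (num_tokens, topk)
    combined = (expert_out * topk_weights.unsqueeze(-1)).sum(dim=1)
    return combined * routed_scaling_factor
```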
@@ -414,31 +406,59 @@ class AscendFusedMoE(FusedMoE):
         num_redundant_experts=0,
         has_bias=False,
     ):
-        super().__init__(
-            num_experts,
-            top_k,
-            hidden_size,
-            intermediate_size,
-            params_dtype,
-            reduce_results,
-            renormalize,
-            use_grouped_topk,
-            num_expert_group,
-            topk_group,
-            quant_config,
-            tp_size,
-            ep_size,
-            dp_size,
-            prefix,
-            custom_routing_function,
-            scoring_func,
-            e_score_correction_bias,
-            apply_router_weight_on_input,
-            activation,
-            enable_eplb,
-            num_redundant_experts,
-            has_bias,
-        )
+        if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
+            super().__init__(
+                num_experts,
+                top_k,
+                hidden_size,
+                intermediate_size,
+                params_dtype,
+                reduce_results,
+                renormalize,
+                use_grouped_topk,
+                num_expert_group,
+                topk_group,
+                quant_config,
+                tp_size,
+                ep_size,
+                dp_size,
+                prefix,
+                custom_routing_function,
+                scoring_func,
+                e_score_correction_bias,
+                apply_router_weight_on_input,
+                activation,
+                enable_eplb,
+                num_redundant_experts,
+                has_bias,
+            )
+        else:
+            super().__init__(
+                num_experts,
+                top_k,
+                hidden_size,
+                intermediate_size,
+                params_dtype,
+                reduce_results,
+                renormalize,
+                use_grouped_topk,
+                num_expert_group,
+                topk_group,
+                quant_config,
+                tp_size,
+                ep_size,
+                dp_size,
+                prefix,
+                custom_routing_function,
+                scoring_func,
+                routed_scaling_fator,
+                e_score_correction_bias,
+                apply_router_weight_on_input,
+                activation,
+                enable_eplb,
+                num_redundant_experts,
+                has_bias,
+            )
 
         setup_token_dispatchers(self.moe_config.ep_size,
                                 top_k=self.top_k,
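The branching exists because vLLM 0.10.1/0.10.1.1 ship a `FusedMoE.__init__` without the routing-scale argument, while newer versions take it positionally after `scoring_func`. A hedged sketch of the same gate expressed with keyword arguments, assuming the upstream parameter is named `routed_scaling_factor` (unverified here):

```python
from vllm_ascend.utils import vllm_version_is  # helper the diff also relies on

def moe_extra_kwargs(routed_scaling_fator: float) -> dict:
    # Older vLLM FusedMoE signatures do not accept a routing-scale argument.
    if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
        return {}
    return {"routed_scaling_factor": routed_scaling_fator}
```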
vllm_ascend/torchair/models/torchair_pangu_moe.py (new file, 1119 lines)
File diff suppressed because it is too large.
@@ -173,6 +173,11 @@ def register_torchair_model():
         "Qwen3MoeForCausalLM",
         "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM")
 
+    ModelRegistry.register_model(
+        "PanguProMoEForCausalLM",
+        "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
+    )
+
 
 def torchair_quant_method_register():
     from vllm_ascend.quantization.quantizer import \
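`ModelRegistry.register_model` accepts a `"module.path:ClassName"` string, so the 1119-line torchair PanGu module is only imported when that architecture is actually requested. A sketch of the registration and of the later lazy resolution (the resolution call is simplified):

```python
from vllm import ModelRegistry

# Lazy registration: passing an import string defers loading the module.
ModelRegistry.register_model(
    "PanguProMoEForCausalLM",
    "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM",
)

# When a checkpoint's config lists this architecture, vLLM resolves the
# class, importing vllm_ascend.torchair.models.torchair_pangu_moe then.
model_cls, arch = ModelRegistry.resolve_model_cls(["PanguProMoEForCausalLM"])
```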