[main] convert the format of gmm to nz (#2474)

### What this PR does / why we need it?
convert the format of gmm to nz

### Does this PR introduce _any_ user-facing change?
not involved

### How was this patch tested?
ut: test_fused_ops.py and e2e: test_fused_moe.py

**performance**:
(qwen3 30B, 2k->20k)

base:
Total Token throughput (tok/s):          719.93

gmm nz:
Total Token throughput (tok/s):          728.52


- vLLM version: v0.10.1.1
- vLLM main:
bfc1edc9f5

Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
huangxialu
2025-08-27 11:25:02 +08:00
committed by GitHub
parent c0e12143a3
commit 6881c19458
2 changed files with 18 additions and 7 deletions

View File

@@ -50,6 +50,10 @@ def mock_dp_and_tp_group(mocker):
return mock_group
def mock_npu_format_cast(weight_data, format):
return weight_data
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
# init dist env patch
@@ -310,12 +314,14 @@ class TestAscendUnquantizedFusedMoEMethod:
layer.w13_weight.data = torch.randn(16, 32)
layer.w2_weight.data = torch.randn(16, 32)
moe_method.process_weights_after_loading(layer)
with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
patch('vllm_ascend.utils.is_310p', return_value=False):
moe_method.process_weights_after_loading(layer)
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])

View File

@@ -49,8 +49,8 @@ from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_all_reduce_merge_state,
get_ascend_soc_version,
get_rm_router_logits_state, is_310p)
@@ -866,6 +866,11 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
def apply(
self,