[Refact]Refact MLA/SFA weight prefetch to consist with moe weight prefetch (#6629)

### What this PR does / why we need it?
1. [Refact] Refact MLA/SFA weight prefetch to consist with moe weight
prefetch
2. Remove duplicated o_proj weight prefetch in forward for MLA/SFA

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?

1) Performance result:
Perf test data:
*) MLA:

| | 1st test | 2nd test | Output Token Throughput(Avg) | Performance
improvement percentage |
| --- | --- | --- | --- | --- |
| o_proj duplicate prefetch | 11.9669 token/s | 12.0287 token/s |
11.9978 |
| o_proj no duplicate prefetch | 12.5594 token/s | 12.6216 token/s |
12.5905 | 4.94%| |

single layer performace improve: 5%~8%

*) SFA:

| | 1st test | 2nd test | Output Token Throughput(Avg) | Performance
improvement percentage |
| --- | --- | --- | --- | --- |
| o_proj duplicate prefetch | 13.0523 token/s | 13.1084 token/s |
13.08035 | |
| o_proj no duplicate prefetch | 13.9844 token/s | 14.1678 token/s |
14.0761 | 7.6% |

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: leo-pony <nengjunma@outlook.com>
This commit is contained in:
Nengjun Ma
2026-02-10 14:14:37 +08:00
committed by GitHub
parent 2a826b5fad
commit 66b60c9440
15 changed files with 98 additions and 56 deletions

View File

@@ -248,9 +248,10 @@ class TestAscendMLAImpl(TestBase):
self.assertEqual(self.impl.dcp_size, 2)
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
return_value=MagicMock())
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_mla_preprocess_dcp(self, magic_npu_fetch,
def test_mla_preprocess_dcp(self, mock_get_weight_prefetch_method,
mock_maybe_all_gather_and_maybe_unpad):
self.impl.num_kv_heads = 1
@@ -309,7 +310,6 @@ class TestAscendMLAImpl(TestBase):
self.impl.qk_rope_head_dim)
]
magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
decode_res, prefill_res = self.impl._mla_preprocess(
@@ -324,9 +324,10 @@ class TestAscendMLAImpl(TestBase):
@patch('torch_npu._npu_reshape_and_cache')
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
@patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
return_value=MagicMock())
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_mla_preprocess_pcp(self, magic_npu_fetch,
def test_mla_preprocess_pcp(self, mock_get_weight_prefetch_method,
mock_maybe_all_gather_and_maybe_unpad,
mock_npu_reshape_and_cache):
self.impl.num_kv_heads = 1
@@ -389,7 +390,6 @@ class TestAscendMLAImpl(TestBase):
self.impl.qk_rope_head_dim)
]
magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
self.impl.kv_a_layernorm = MagicMock()

View File

@@ -967,10 +967,10 @@ class TestAscendMLAImpl(TestBase):
mock_npu_fused_infer_attention_score.assert_called_once()
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch,
@patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method",
return_value=MagicMock())
def test_mla_preprocess(self, mock_get_weight_prefetch_method,
mock_maybe_all_gather_and_maybe_unpad):
magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
batch_size = 4
seq_len = 8

View File

@@ -53,9 +53,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor, default_vllm_config):
@pytest.mark.skipif(is_310p_hw(), reason="non_310P device unittest case.")
@patch("vllm_ascend.ops.activation.get_weight_prefetch_method",
return_value=MagicMock())
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(
mock_swiglu,
mock_get_weight_prefetch_method,
dummy_tensor,
default_vllm_config,
):

View File

@@ -296,6 +296,8 @@ class TestCumsumGroupList(TestBase):
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
return_value=MagicMock())
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
@@ -306,7 +308,8 @@ class TestUnifiedApplyMLP(TestBase):
mock_npu_dynamic_quant,
mock_npu_grouped_matmul,
mock_soc_version,
mock_get_forward_context):
mock_get_forward_context,
mock_get_weight_prefetch_method):
mock_forward_context = MagicMock()
mock_forward_context.moe_comm_type = MoECommType.MC2
@@ -402,13 +405,16 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.fused_moe.moe_mlp.HAS_TRITON', False)
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method',
return_value=MagicMock())
@patch('vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_get_forward_context):
mock_npu_grouped_matmul, mock_get_forward_context,
mock_get_weight_prefetch_method):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
@@ -505,6 +511,8 @@ class TestUnifiedApplyMLP(TestBase):
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_weight_prefetch_method",
return_value=MagicMock())
@patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@@ -513,7 +521,8 @@ class TestUnifiedApplyMLP(TestBase):
def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
mock_npu_swiglu, mock_npu_grouped_matmul,
mock_get_forward_context):
mock_get_forward_context,
mock_get_weight_prefetch_method):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True

View File

@@ -83,7 +83,9 @@ class TestAscendUnquantizedLinearMethod(TestBase):
class TestAscendRowParallelLinear(BaseLinearTest):
def test_mlp_optimize(self):
@patch("vllm_ascend.ops.linear_op.get_weight_prefetch_method",
return_value=MagicMock())
def test_mlp_optimize(self, mock_get_weight_prefetch_method):
ascend_config._ASCEND_CONFIG = MagicMock()
ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False
@@ -100,7 +102,9 @@ class TestAscendRowParallelLinear(BaseLinearTest):
input_tensor = torch.randn(16, 8)
linear(input_tensor)
def test_oproj_tp(self):
@patch("vllm_ascend.ops.linear_op.get_weight_prefetch_method",
return_value=MagicMock())
def test_oproj_tp(self, mock_get_weight_prefetch_method):
config._current_vllm_config = MagicMock()

View File

@@ -57,8 +57,7 @@ def select_experts(
"""
# prefetch w1_w3_proj.weight preprocess
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
router_logits=router_logits,

View File

@@ -43,9 +43,13 @@ from vllm_ascend.ops.layer_shard_linear import (
register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, maybe_trans_nz, weak_ref_tensors
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
get_weight_prefetch_method,
maybe_trans_nz,
weak_ref_tensors,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
@@ -703,7 +707,6 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.enable_kv_nz
self.ring_mla_mask_size = 512
@@ -1412,8 +1415,9 @@ class AscendMLAImpl(MLAAttentionImpl):
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
if self.fused_qkv_a_proj is not None:
maybe_npu_prefetch(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
@@ -1545,14 +1549,13 @@ class AscendMLAImpl(MLAAttentionImpl):
o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
maybe_npu_prefetch(
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch,
linear_layer=self.o_proj,
)
output[...] = self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0]
del o_proj_input

View File

@@ -37,7 +37,6 @@ from vllm_ascend.ops.layer_shard_linear import (
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_ND,
@@ -45,6 +44,7 @@ from vllm_ascend.utils import (
dispose_layer,
enable_dsa_cp,
enable_dsa_cp_with_layer_shard,
get_weight_prefetch_method,
maybe_trans_nz,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -410,7 +410,6 @@ class AscendSFAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
# In sfa, prefill and decode have the same calculation formula,
# so do not distinguish between prefill and decode here.
@@ -800,8 +799,9 @@ class AscendSFAImpl(MLAAttentionImpl):
)
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
@@ -917,11 +917,12 @@ class AscendSFAImpl(MLAAttentionImpl):
)
attn_output = self._v_up_proj(attn_output)
maybe_npu_prefetch(
weight_prefetch_method = get_weight_prefetch_method()
weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
inputs=self.o_proj.weight,
dependency=attn_output,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch,
linear_layer=self.o_proj,
)
if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only:

View File

@@ -34,9 +34,7 @@ class AscendSiluAndMul(SiluAndMul):
import torch_npu
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(weight_prefetch_method.MLP_DOWN, x)
out = torch_npu.npu_swiglu(x)
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(out)
return out

View File

@@ -59,8 +59,7 @@ def select_experts(
"""
# prefetch w1_w3_proj.weight preprocess
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
hidden_states=hidden_states,
top_k=top_k,

View File

@@ -100,8 +100,7 @@ def quant_apply_mlp(
_output_dtype = w2_scale[0].dtype
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and w1_offset is None and is_mc2:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):

View File

@@ -66,8 +66,7 @@ class AscendRMSNorm(RMSNorm):
x.add_(self.bias)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x)
weight_prefetch_method.maybe_prefetch_mlp_weight_postprocess(x)
return x

View File

@@ -149,10 +149,9 @@ class CustomRowParallelOp(CustomLinearOp):
def apply(self, input_):
output, output_bias = self.apply_impl(input_)
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(
weight_prefetch_method.MLP_GATE_UP, output, self.prefix
)
weight_prefetch_method.maybe_prefetch_mlp_weight_preprocess(
weight_prefetch_method.MLP_GATE_UP, output, self.prefix
)
if not self.return_bias:
return output

View File

@@ -47,6 +47,7 @@ class WeightPrefetchMethod:
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
self.is_moe = is_moe_model(get_current_vllm_config())
self.mla_sfa_prefetch_enable = weight_prefetch_config.enabled
self.attn = ModuleWeightPrefetchConfig(
module_name="attn",
@@ -94,6 +95,9 @@ class WeightPrefetchMethod:
if not self.moe.is_active_this_forward:
return
forward_context = get_forward_context()
if not forward_context or forward_context.model_instance is None:
return
# layer_idx is subtracted by 1 because layer_idx was incremented by 1 at layernorm.
weight = forward_context.model_instance.model.layers[forward_context.layer_idx - 1].mlp.experts.w13_weight
weight_size = weight.data.element_size() * weight.data.numel() * self.moe.prefetch_ratio.get(prefix, 0)
@@ -184,6 +188,33 @@ class WeightPrefetchMethod:
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
def maybe_prefetch_mla_or_sla_weight_in_current_stream(
self,
inputs: torch.Tensor,
dependency: torch.Tensor,
max_size: int = 0,
linear_layer: torch.nn.Module | None = None,
) -> None:
if not self.mla_sfa_prefetch_enable:
return
# The prefetching of the weights of the o_proj matrix in the W8A8
# scene is already performed once in AscendW8A8LinearMethod, so it
# is not needed here.
if linear_layer is not None:
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
if isinstance(
getattr(linear_layer.quant_method, "quant_method", None),
AscendW8A8LinearMethod,
):
return
input_size = inputs.element_size() * inputs.numel()
if max_size <= 0 or max_size > input_size:
max_size = input_size
torch.ops.vllm.prefetch_preprocess(weight=inputs, start_flag=dependency, max_weight_size=int(max_size))
def maybe_npu_prefetch(
inputs: torch.Tensor, dependency: torch.Tensor, max_size: int = 0, offset: int = 0, *, enabled: bool = True

View File

@@ -82,12 +82,11 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
layer_cls_name = layer.__class__.__name__
weight_prefetch_method = get_weight_prefetch_method()
# prefetch qkvo_proj.weight preprocess
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
layer_cls_name=layer_cls_name,
weight=layer.weight,
start_flag=x,
)
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
layer_cls_name=layer_cls_name,
weight=layer.weight,
start_flag=x,
)
try:
quant_comm_config = layer._quant_comm_config
except AttributeError:
@@ -117,11 +116,10 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
)
# prefetch qkvo_proj.weight postprocess
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
layer_cls_name=layer_cls_name,
stop_flag=x,
)
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
layer_cls_name=layer_cls_name,
stop_flag=x,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None