[2/N][Feat] Attention and MoE weight prefetch in Qwen3MoE models (#3203)

### What this PR does / why we need it?

- Refactor and integrate a unified `WeightPrefetchMethod`
- Integrate prefetching of `gate_up_proj.weight` in quantized MoE modules
- Prefetching these weights ahead of matmul-like operators improves
  performance by reducing L2 cache transfer latency (see the sketch below)

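For context, the primitive this builds on is `torch_npu.npu_prefetch`, which the existing `maybe_npu_prefetch` helper in this repo already wraps. Below is a minimal sketch of the idea; the helper name and ratio value are illustrative, and an Ascend NPU with `torch_npu` installed is assumed.

```python
import torch
import torch_npu  # Ascend PyTorch extension; provides npu_prefetch


def prefetch_fraction(weight: torch.Tensor,
                      dependency: torch.Tensor,
                      ratio: float = 0.8) -> None:
    """Stage a fraction of `weight`'s bytes toward cache, ordered against `dependency`."""
    max_size = int(weight.element_size() * weight.numel() * ratio)
    # npu_prefetch(input, dependency, max_size, offset) issues an asynchronous
    # prefetch ordered against the op that produces `dependency`, so the
    # following matmul does not stall on the weight transfer.
    torch_npu.npu_prefetch(weight, dependency, max_size, 0)
```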
### Does this PR introduce _any_ user-facing change?

Adds a new option to `--additional-config`:
```json
{
    "weight_prefetch_config": {
        "enabled": true,
        "prefetch_ratio": {
            "moe": {
                "gate_up": 0.8
            }
        }
    }
}
```
This feature is enabled by default and can be disabled through this configuration.
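For reference, a hedged sketch of how this might be passed from Python, assuming vLLM's `LLM` entry point forwards `additional_config` to the Ascend plugin (the model name is only a placeholder):

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",  # placeholder Qwen3MoE checkpoint
    additional_config={
        "weight_prefetch_config": {
            "enabled": True,
            "prefetch_ratio": {
                "moe": {
                    "gate_up": 0.8
                }
            }
        }
    },
)
```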

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Author: yuzhup
Date: 2025-10-14 20:16:33 +08:00
Committed by: GitHub
Parent: 07e39620ea
Commit: 78777237a9
9 changed files with 160 additions and 100 deletions

View File

@@ -73,10 +73,10 @@ ascend_scheduler_config also support the options from [vllm scheduler config](ht
 **weight_prefetch_config**
 | Name | Type | Default | Description |
-|------------------|------|------------------------------------|------------------------------------|
-| `enabled` | bool | `False` | Whether to enable weight prefetch. |
-| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}}` | Prefetch ratio of each weights. |
+|------------------|------|-------------------------------------------------------------|------------------------------------|
+| `enabled` | bool | `False` | Whether to enable weight prefetch. |
+| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weights. |
 ### Example
@@ -104,6 +104,9 @@ An example of additional configuration is as follows:
"qkv": 1.0, "qkv": 1.0,
"o": 1.0, "o": 1.0,
}, },
"moe": {
"gate_up": 0.8
}
}, },
}, },
"multistream_overlap_shared_expert": True, "multistream_overlap_shared_expert": True,

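The ratio is applied to the weight's byte size at the moment the prefetch is issued (see the `weight_size` computation in the `WeightPrefetchMethod` diff further down). A small illustrative calculation, with made-up tensor shapes:

```python
import torch

# Made-up w13 (gate_up) expert weight: (num_experts, 2 * intermediate_size, hidden_size)
weight = torch.empty(8, 2 * 128, 256, dtype=torch.int8)
prefetch_ratio = 0.8  # the "moe.gate_up" value from the config above
max_weight_size = int(weight.element_size() * weight.numel() * prefetch_ratio)
print(max_weight_size)  # byte budget handed to the prefetch op
```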
View File

@@ -291,7 +291,9 @@ def test_select_experts(
         custom_routing_function.return_value = (mock_weights, mock_ids)
         with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk"
-                   ) as mock_native_grouped_topk:
+                   ) as mock_native_grouped_topk, \
+             patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+                   return_value=MagicMock(weight_prefetch_method=MagicMock())):
             mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
                 x)
@@ -325,7 +327,9 @@ def test_select_experts(
 @pytest.mark.parametrize("device", DEVICE)
 def test_select_experts_invalid_scoring_func(device: str):
-    with pytest.raises(ValueError,
-                       match="Unsupported scoring function: invalid"):
+    with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+               return_value=MagicMock(weight_prefetch_method=MagicMock())), \
+         pytest.raises(ValueError,
+                       match="Unsupported scoring function: invalid"):
         select_experts(hidden_states=torch.randn(1, 128, device=device),
                        router_logits=torch.randn(1, 8, device=device),

View File

@@ -92,14 +92,16 @@ def mock_dist_env(mocker: MockerFixture):
     mock_moe_comm_method.finalize.side_effect = mock_finalize
     dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
-    mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
-                                         moe_comm_type=MoECommType.MC2,
-                                         max_tokens_across_dp=10,
-                                         dp_metadata=dp_metadata,
-                                         mc2_mask=torch.zeros(
-                                             16, dtype=torch.bool),
-                                         padded_num_tokens=16,
-                                         with_quant=False)
+    mock_weight_prefetch_method = MagicMock()
+    mock_forward_context_obj = MagicMock(
+        moe_comm_method=mock_moe_comm_method,
+        moe_comm_type=MoECommType.MC2,
+        max_tokens_across_dp=10,
+        dp_metadata=dp_metadata,
+        mc2_mask=torch.zeros(16, dtype=torch.bool),
+        padded_num_tokens=16,
+        with_quant=False,
+        weight_prefetch_method=mock_weight_prefetch_method)
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
@@ -132,7 +134,9 @@ def mock_dist_env(mocker: MockerFixture):
         patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
               return_value=None), \
         patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
-              return_value=None):
+              return_value=None), \
+        patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
+              return_value=mock_forward_context_obj):
         yield {
             'mock_forward_context_obj': mock_forward_context_obj,

View File

@@ -755,6 +755,14 @@ class TestSelectExperts(TestBase):
         self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
         self.router_logits = torch.randn(self.num_tokens, self.num_experts)
+        self.mock_ctx = MagicMock()
+        self.mock_ctx.weight_prefetch_method = MagicMock()
+        patcher = patch(
+            'vllm_ascend.ops.moe.experts_selector.get_forward_context',
+            return_value=self.mock_ctx)
+        self.addCleanup(patcher.stop)
+        patcher.start()
+
     @patch('torch_npu.npu_moe_gating_top_k_softmax')
     def test_softmax_scoring(self, mock_topk):
         """Test softmax scoring function"""

View File

@@ -216,6 +216,9 @@ class WeightPrefetchConfig:
"qkv": 1.0, "qkv": 1.0,
"o": 1.0, "o": 1.0,
}, },
"moe": {
"gate_up": 0.8
}
} }
def __init__(self, weight_prefetch_config: dict): def __init__(self, weight_prefetch_config: dict):

View File

@@ -145,7 +145,7 @@ def set_ascend_forward_context(
         forward_context.prefetch_mlp_gate_up_proj = False
         forward_context.prefetch_mlp_down_proj = False
         forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
-        # TODO(yuzhup): integrate moe weight prefetch method
+        forward_context.model_instance = model_instance
         forward_context.weight_prefetch_method = weight_prefetch_method
         # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
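Stashing `model_instance` on the forward context lets the MoE prefetch reach the upcoming layer's expert weights by index. A sketch of that lookup, mirroring the attribute path used in the `WeightPrefetchMethod` diff further down (the helper name is ours; `layer_idx` bookkeeping is simplified):

```python
def next_gate_up_weight(forward_context):
    """Return the w13 (gate_up) expert weight of the MoE layer about to execute."""
    layer = forward_context.model_instance.model.layers[forward_context.layer_idx]
    weight = layer.mlp.experts.w13_weight
    forward_context.layer_idx += 1  # advance so the next MoE block targets its own layer
    return weight
```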

View File

@@ -18,6 +18,7 @@ from typing import Callable, Optional
 import torch
 import torch_npu
+from vllm.forward_context import get_forward_context


 def return_row_idx(hidden_states, top_k):
@@ -65,7 +66,11 @@ def select_experts(hidden_states: torch.Tensor,
         topk_weights: router weights of shape (num_tokens, top_k).
         topk_ids: selected expert IDs of shape (num_tokens, top_k).
     """
+    # prefetch w1_w3_proj.weight preprocess
+    weight_prefetch_method = get_forward_context().weight_prefetch_method
+    if weight_prefetch_method:
+        weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
+            hidden_states, "gate_up")
     topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
         hidden_states=hidden_states,
         router_logits=router_logits,

View File

@@ -78,6 +78,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
     bias1, bias2 = None, None
     _output_dtype = w2_scale.dtype
+    weight_prefetch_method = get_forward_context().weight_prefetch_method
+    if weight_prefetch_method:
+        weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
+            hidden_states)
     is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
     if w1_scale_bias is None and is_mc2:
         if fusion and not dynamic_eplb:
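Taken together with the `select_experts` hunk above, the prefetch is bracketed: the preprocess call starts streaming `w13_weight` while expert routing runs, and the postprocess call in `quant_apply_mlp` closes the window just before the gate_up matmul. A condensed, hedged sketch of the call pattern (`run_routing` and `run_gate_up` are stand-ins for the real kernels):

```python
from vllm.forward_context import get_forward_context


def moe_layer_sketch(hidden_states, run_routing, run_gate_up):
    wpm = get_forward_context().weight_prefetch_method
    if wpm:
        # begin streaming a ratio of this layer's w13_weight (gate_up)
        wpm.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
    topk_weights, topk_ids = run_routing(hidden_states)  # prefetch overlaps routing
    if wpm:
        # signal the end of the prefetch window before the expert matmul
        wpm.maybe_prefetch_moe_weight_postprocess(hidden_states)
    return run_gate_up(hidden_states, topk_weights, topk_ids)
```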

View File

@@ -1,83 +1,112 @@
 from dataclasses import dataclass, field

 import torch
 import torch_npu
+from vllm.forward_context import get_forward_context

 from vllm_ascend.ascend_config import WeightPrefetchConfig
 from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
                                     AscendRowParallelLinear)

 SUPPORTED_MODULES = ["attn", "mlp", "moe"]
+MOE_PREFETCH_TOKEN_THRESHOLD = 96


 @dataclass
 class ModuleWeightPrefetchConfig:
     module_name: str
     enable: bool = False
+    is_active_this_forward: bool = False
     prefetch_ratio: dict = field(default_factory=dict)
     linear_prefix_map: dict = field(default_factory=dict)

     def __post_init__(self) -> None:
         self.prefetch_ratio = {
             prefix: ratio
             for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
         }

         assert self.module_name in SUPPORTED_MODULES, (
             f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
         )

         if self.module_name in SUPPORTED_MODULES:
             self.enable = self.enable and any(self.prefetch_ratio.values()) > 0


 class WeightPrefetchMethod:
     """
     Unified weight prefetch method.
     """

     def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
         self.attn = ModuleWeightPrefetchConfig(
             module_name="attn",
             enable=weight_prefetch_config.enabled,
             prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
                 "attn", {}),
             linear_prefix_map={
                 AscendQKVParallelLinear.__name__: "qkv",
                 AscendRowParallelLinear.__name__: "o",
             })
+        self.moe = ModuleWeightPrefetchConfig(
+            module_name="moe",
+            enable=weight_prefetch_config.enabled,
+            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
+                "moe", {}))

     def maybe_prefetch_attn_weight_preprocess(
             self, layer_cls_name: str, weight: torch.Tensor,
             start_flag: torch.Tensor) -> None:
         if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
             return

         prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
         weight_size = weight.data.element_size() * weight.data.numel(
         ) * self.attn.prefetch_ratio.get(prefix, 0)

         torch.ops.vllm.prefetch_preprocess(weight=weight,
                                            start_flag=start_flag,
                                            max_weight_size=int(weight_size))

     def maybe_prefetch_attn_weight_postprocess(
             self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
         if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
             return

         torch.ops.vllm.prefetch_postprocess(stop_flag)

+    def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix):
+        self.moe.is_active_this_forward = hidden_states.shape[
+            0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False
+        if not self.moe.is_active_this_forward:
+            return
+        forward_context = get_forward_context()
+        weight = forward_context.model_instance.model.layers[
+            forward_context.layer_idx].mlp.experts.w13_weight
+        weight_size = weight.data.element_size() * weight.data.numel(
+        ) * self.moe.prefetch_ratio.get(prefix, 0)
+        torch.ops.vllm.prefetch_preprocess(weight=weight,
+                                           start_flag=None,
+                                           max_weight_size=int(weight_size))
+        forward_context.layer_idx += 1
+
+    def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor):
+        if not self.moe.is_active_this_forward:
+            return
+        torch.ops.vllm.prefetch_postprocess(stop_flag)
+

 def maybe_npu_prefetch(inputs: torch.Tensor,
                        dependency: torch.Tensor,
                        max_size: int = 0,
                        offset: int = 0,
                        *,
                        enabled: bool = True) -> None:
     if not enabled:
         return
     input_size = inputs.element_size() * inputs.numel()
     if max_size <= 0 or max_size > input_size:
         max_size = input_size
     torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
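The MoE path only activates for sufficiently large batches: per the constants above, the prefetch fires when the feature is enabled, the `gate_up` ratio is in (0, 1], and the forward pass carries at least `MOE_PREFETCH_TOKEN_THRESHOLD` (96) tokens. A small, self-contained restatement of that gating; the helper name is ours, and the small-batch rationale in the comment is an assumption:

```python
MOE_PREFETCH_TOKEN_THRESHOLD = 96  # same constant as in the diff above


def should_prefetch_moe(enabled: bool, gate_up_ratio: float, num_tokens: int) -> bool:
    # ratios outside [0, 1] are filtered out, and an all-zero ratio disables the module
    ratio_ok = 0 <= gate_up_ratio <= 1 and gate_up_ratio > 0
    # short prefill/decode batches skip the prefetch entirely (assumed rationale:
    # the transfer is not worth hiding for small matmuls)
    return enabled and ratio_ok and num_tokens >= MOE_PREFETCH_TOKEN_THRESHOLD


assert should_prefetch_moe(True, 0.8, 128) is True
assert should_prefetch_moe(True, 0.8, 32) is False   # below the 96-token threshold
assert should_prefetch_moe(True, 0.0, 128) is False  # zero ratio turns the module off
```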