[2/N][Feat] Attention and MoE weight prefetch in Qwen3MoE models (#3203)
### What this PR does / why we need it?
- Refacotr and integrate a unified `WeightPrefetchMethod`
- Integrate `gate_up_proj.weight` in quantized Attention modules
- Prefetching these weights ahead of matmul-like operators imporves
performance by reducing L2 cache transfer latency
### Does this PR introduce _any_ user-facing change?
Add a new config in `--additional-config` for configuration:
```json
{
"weight_prefetch_config": {
"enabled": True,
"prefetch_ratio": {
"moe": {
"gate_up": 0.8
},
},
},
}
```
This feature is enabled by default, and can be disabled through this
configuration
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
---------
Signed-off-by: yuzhup <15705211260@163.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
|
||||
def return_row_idx(hidden_states, top_k):
|
||||
@@ -65,7 +66,11 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
topk_weights: router weights of shape (num_tokens, top_k).
|
||||
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
||||
"""
|
||||
|
||||
# prefetch w1_w3_proj.weight preprocess
|
||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||
hidden_states, "gate_up")
|
||||
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
|
||||
@@ -78,6 +78,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
|
||||
hidden_states)
|
||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
if fusion and not dynamic_eplb:
|
||||
|
||||
Reference in New Issue
Block a user