[2/N][Feat] Attention and MoE weight prefetch in Qwen3MoE models (#3203)
### What this PR does / why we need it?
- Refacotr and integrate a unified `WeightPrefetchMethod`
- Integrate `gate_up_proj.weight` in quantized Attention modules
- Prefetching these weights ahead of matmul-like operators imporves
performance by reducing L2 cache transfer latency
### Does this PR introduce _any_ user-facing change?
Add a new config in `--additional-config` for configuration:
```json
{
"weight_prefetch_config": {
"enabled": True,
"prefetch_ratio": {
"moe": {
"gate_up": 0.8
},
},
},
}
```
This feature is enabled by default, and can be disabled through this
configuration
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
---------
Signed-off-by: yuzhup <15705211260@163.com>
This commit is contained in:
@@ -216,6 +216,9 @@ class WeightPrefetchConfig:
|
||||
"qkv": 1.0,
|
||||
"o": 1.0,
|
||||
},
|
||||
"moe": {
|
||||
"gate_up": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self, weight_prefetch_config: dict):
|
||||
|
||||
@@ -145,7 +145,7 @@ def set_ascend_forward_context(
|
||||
forward_context.prefetch_mlp_gate_up_proj = False
|
||||
forward_context.prefetch_mlp_down_proj = False
|
||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||
# TODO(yuzhup): integrate moe weight prefetch method
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.weight_prefetch_method = weight_prefetch_method
|
||||
|
||||
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
|
||||
def return_row_idx(hidden_states, top_k):
|
||||
@@ -65,7 +66,11 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
topk_weights: router weights of shape (num_tokens, top_k).
|
||||
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
||||
"""
|
||||
|
||||
# prefetch w1_w3_proj.weight preprocess
|
||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||
hidden_states, "gate_up")
|
||||
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
|
||||
@@ -78,6 +78,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
|
||||
hidden_states)
|
||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
if fusion and not dynamic_eplb:
|
||||
|
||||
@@ -1,83 +1,112 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
|
||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleWeightPrefetchConfig:
|
||||
module_name: str
|
||||
enable: bool = False
|
||||
prefetch_ratio: dict = field(default_factory=dict)
|
||||
linear_prefix_map: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.prefetch_ratio = {
|
||||
prefix: ratio
|
||||
for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
|
||||
}
|
||||
|
||||
assert self.module_name in SUPPORTED_MODULES, (
|
||||
f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
|
||||
)
|
||||
|
||||
if self.module_name in SUPPORTED_MODULES:
|
||||
self.enable = self.enable and any(self.prefetch_ratio.values()) > 0
|
||||
|
||||
|
||||
class WeightPrefetchMethod:
|
||||
"""
|
||||
Unified weight prefetch method.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
|
||||
self.attn = ModuleWeightPrefetchConfig(
|
||||
module_name="attn",
|
||||
enable=weight_prefetch_config.enabled,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||
"attn", {}),
|
||||
linear_prefix_map={
|
||||
AscendQKVParallelLinear.__name__: "qkv",
|
||||
AscendRowParallelLinear.__name__: "o",
|
||||
})
|
||||
|
||||
def maybe_prefetch_attn_weight_preprocess(
|
||||
self, layer_cls_name: str, weight: torch.Tensor,
|
||||
start_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||
return
|
||||
|
||||
prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.attn.prefetch_ratio.get(prefix, 0)
|
||||
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=start_flag,
|
||||
max_weight_size=int(weight_size))
|
||||
|
||||
def maybe_prefetch_attn_weight_postprocess(
|
||||
self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||
return
|
||||
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
|
||||
|
||||
def maybe_npu_prefetch(inputs: torch.Tensor,
|
||||
dependency: torch.Tensor,
|
||||
max_size: int = 0,
|
||||
offset: int = 0,
|
||||
*,
|
||||
enabled: bool = True) -> None:
|
||||
if not enabled:
|
||||
return
|
||||
input_size = inputs.element_size() * inputs.numel()
|
||||
if max_size <= 0 or max_size > input_size:
|
||||
max_size = input_size
|
||||
torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
|
||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||
MOE_PREFETCH_TOKEN_THRESHOLD = 96
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleWeightPrefetchConfig:
|
||||
module_name: str
|
||||
enable: bool = False
|
||||
is_active_this_forward: bool = False
|
||||
prefetch_ratio: dict = field(default_factory=dict)
|
||||
linear_prefix_map: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.prefetch_ratio = {
|
||||
prefix: ratio
|
||||
for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
|
||||
}
|
||||
|
||||
assert self.module_name in SUPPORTED_MODULES, (
|
||||
f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
|
||||
)
|
||||
|
||||
if self.module_name in SUPPORTED_MODULES:
|
||||
self.enable = self.enable and any(self.prefetch_ratio.values()) > 0
|
||||
|
||||
|
||||
class WeightPrefetchMethod:
|
||||
"""
|
||||
Unified weight prefetch method.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
|
||||
self.attn = ModuleWeightPrefetchConfig(
|
||||
module_name="attn",
|
||||
enable=weight_prefetch_config.enabled,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||
"attn", {}),
|
||||
linear_prefix_map={
|
||||
AscendQKVParallelLinear.__name__: "qkv",
|
||||
AscendRowParallelLinear.__name__: "o",
|
||||
})
|
||||
self.moe = ModuleWeightPrefetchConfig(
|
||||
module_name="moe",
|
||||
enable=weight_prefetch_config.enabled,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||
"moe", {}))
|
||||
|
||||
def maybe_prefetch_attn_weight_preprocess(
|
||||
self, layer_cls_name: str, weight: torch.Tensor,
|
||||
start_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||
return
|
||||
|
||||
prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.attn.prefetch_ratio.get(prefix, 0)
|
||||
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=start_flag,
|
||||
max_weight_size=int(weight_size))
|
||||
|
||||
def maybe_prefetch_attn_weight_postprocess(
|
||||
self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||
return
|
||||
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
|
||||
def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix):
|
||||
self.moe.is_active_this_forward = hidden_states.shape[
|
||||
0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False
|
||||
if not self.moe.is_active_this_forward:
|
||||
return
|
||||
forward_context = get_forward_context()
|
||||
weight = forward_context.model_instance.model.layers[
|
||||
forward_context.layer_idx].mlp.experts.w13_weight
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.moe.prefetch_ratio.get(prefix, 0)
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=None,
|
||||
max_weight_size=int(weight_size))
|
||||
forward_context.layer_idx += 1
|
||||
|
||||
def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor):
|
||||
if not self.moe.is_active_this_forward:
|
||||
return
|
||||
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
|
||||
|
||||
def maybe_npu_prefetch(inputs: torch.Tensor,
|
||||
dependency: torch.Tensor,
|
||||
max_size: int = 0,
|
||||
offset: int = 0,
|
||||
*,
|
||||
enabled: bool = True) -> None:
|
||||
if not enabled:
|
||||
return
|
||||
input_size = inputs.element_size() * inputs.numel()
|
||||
if max_size <= 0 or max_size > input_size:
|
||||
max_size = input_size
|
||||
torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
|
||||
|
||||
Reference in New Issue
Block a user