[2/N][Feat] Attention and MoE weight prefetch in Qwen3MoE models (#3203)
### What this PR does / why we need it?
- Refacotr and integrate a unified `WeightPrefetchMethod`
- Integrate `gate_up_proj.weight` in quantized Attention modules
- Prefetching these weights ahead of matmul-like operators imporves
performance by reducing L2 cache transfer latency
### Does this PR introduce _any_ user-facing change?
Add a new config in `--additional-config` for configuration:
```json
{
"weight_prefetch_config": {
"enabled": True,
"prefetch_ratio": {
"moe": {
"gate_up": 0.8
},
},
},
}
```
This feature is enabled by default, and can be disabled through this
configuration
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
---------
Signed-off-by: yuzhup <15705211260@163.com>
This commit is contained in:
@@ -73,10 +73,10 @@ ascend_scheduler_config also support the options from [vllm scheduler config](ht
|
|||||||
|
|
||||||
**weight_prefetch_config**
|
**weight_prefetch_config**
|
||||||
|
|
||||||
| Name | Type | Default | Description |
|
| Name | Type | Default | Description |
|
||||||
|------------------|------|------------------------------------|------------------------------------|
|
|------------------|------|-------------------------------------------------------------|------------------------------------|
|
||||||
| `enabled` | bool | `False` | Whether to enable weight prefetch. |
|
| `enabled` | bool | `False` | Whether to enable weight prefetch. |
|
||||||
| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}}` | Prefetch ratio of each weights. |
|
| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weights. |
|
||||||
|
|
||||||
### Example
|
### Example
|
||||||
|
|
||||||
@@ -104,6 +104,9 @@ An example of additional configuration is as follows:
|
|||||||
"qkv": 1.0,
|
"qkv": 1.0,
|
||||||
"o": 1.0,
|
"o": 1.0,
|
||||||
},
|
},
|
||||||
|
"moe": {
|
||||||
|
"gate_up": 0.8
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"multistream_overlap_shared_expert": True,
|
"multistream_overlap_shared_expert": True,
|
||||||
|
|||||||
@@ -291,7 +291,9 @@ def test_select_experts(
|
|||||||
custom_routing_function.return_value = (mock_weights, mock_ids)
|
custom_routing_function.return_value = (mock_weights, mock_ids)
|
||||||
|
|
||||||
with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk"
|
with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk"
|
||||||
) as mock_native_grouped_topk:
|
) as mock_native_grouped_topk, \
|
||||||
|
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||||
|
return_value=MagicMock(weight_prefetch_method=MagicMock())):
|
||||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||||
x)
|
x)
|
||||||
|
|
||||||
@@ -325,7 +327,9 @@ def test_select_experts(
|
|||||||
|
|
||||||
@pytest.mark.parametrize("device", DEVICE)
|
@pytest.mark.parametrize("device", DEVICE)
|
||||||
def test_select_experts_invalid_scoring_func(device: str):
|
def test_select_experts_invalid_scoring_func(device: str):
|
||||||
with pytest.raises(ValueError,
|
with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||||
|
return_value=MagicMock(weight_prefetch_method=MagicMock())), \
|
||||||
|
pytest.raises(ValueError,
|
||||||
match="Unsupported scoring function: invalid"):
|
match="Unsupported scoring function: invalid"):
|
||||||
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
||||||
router_logits=torch.randn(1, 8, device=device),
|
router_logits=torch.randn(1, 8, device=device),
|
||||||
|
|||||||
@@ -92,14 +92,16 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
|
|
||||||
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
||||||
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
||||||
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
|
mock_weight_prefetch_method = MagicMock()
|
||||||
moe_comm_type=MoECommType.MC2,
|
mock_forward_context_obj = MagicMock(
|
||||||
max_tokens_across_dp=10,
|
moe_comm_method=mock_moe_comm_method,
|
||||||
dp_metadata=dp_metadata,
|
moe_comm_type=MoECommType.MC2,
|
||||||
mc2_mask=torch.zeros(
|
max_tokens_across_dp=10,
|
||||||
16, dtype=torch.bool),
|
dp_metadata=dp_metadata,
|
||||||
padded_num_tokens=16,
|
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||||
with_quant=False)
|
padded_num_tokens=16,
|
||||||
|
with_quant=False,
|
||||||
|
weight_prefetch_method=mock_weight_prefetch_method)
|
||||||
|
|
||||||
with patch('torch.distributed.get_rank', return_value=0), \
|
with patch('torch.distributed.get_rank', return_value=0), \
|
||||||
patch('torch.distributed.get_world_size', return_value=4), \
|
patch('torch.distributed.get_world_size', return_value=4), \
|
||||||
@@ -132,7 +134,9 @@ def mock_dist_env(mocker: MockerFixture):
|
|||||||
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||||
return_value=None), \
|
return_value=None), \
|
||||||
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||||
return_value=None):
|
return_value=None), \
|
||||||
|
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||||
|
return_value=mock_forward_context_obj):
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
'mock_forward_context_obj': mock_forward_context_obj,
|
'mock_forward_context_obj': mock_forward_context_obj,
|
||||||
|
|||||||
@@ -755,6 +755,14 @@ class TestSelectExperts(TestBase):
|
|||||||
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
||||||
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
||||||
|
|
||||||
|
self.mock_ctx = MagicMock()
|
||||||
|
self.mock_ctx.weight_prefetch_method = MagicMock()
|
||||||
|
patcher = patch(
|
||||||
|
'vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||||
|
return_value=self.mock_ctx)
|
||||||
|
self.addCleanup(patcher.stop)
|
||||||
|
patcher.start()
|
||||||
|
|
||||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||||
def test_softmax_scoring(self, mock_topk):
|
def test_softmax_scoring(self, mock_topk):
|
||||||
"""Test softmax scoring function"""
|
"""Test softmax scoring function"""
|
||||||
|
|||||||
@@ -216,6 +216,9 @@ class WeightPrefetchConfig:
|
|||||||
"qkv": 1.0,
|
"qkv": 1.0,
|
||||||
"o": 1.0,
|
"o": 1.0,
|
||||||
},
|
},
|
||||||
|
"moe": {
|
||||||
|
"gate_up": 0.8
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, weight_prefetch_config: dict):
|
def __init__(self, weight_prefetch_config: dict):
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ def set_ascend_forward_context(
|
|||||||
forward_context.prefetch_mlp_gate_up_proj = False
|
forward_context.prefetch_mlp_gate_up_proj = False
|
||||||
forward_context.prefetch_mlp_down_proj = False
|
forward_context.prefetch_mlp_down_proj = False
|
||||||
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
|
||||||
# TODO(yuzhup): integrate moe weight prefetch method
|
forward_context.model_instance = model_instance
|
||||||
forward_context.weight_prefetch_method = weight_prefetch_method
|
forward_context.weight_prefetch_method = weight_prefetch_method
|
||||||
|
|
||||||
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
|
|
||||||
def return_row_idx(hidden_states, top_k):
|
def return_row_idx(hidden_states, top_k):
|
||||||
@@ -65,7 +66,11 @@ def select_experts(hidden_states: torch.Tensor,
|
|||||||
topk_weights: router weights of shape (num_tokens, top_k).
|
topk_weights: router weights of shape (num_tokens, top_k).
|
||||||
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
||||||
"""
|
"""
|
||||||
|
# prefetch w1_w3_proj.weight preprocess
|
||||||
|
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||||
|
if weight_prefetch_method:
|
||||||
|
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||||
|
hidden_states, "gate_up")
|
||||||
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
|
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
|
|||||||
@@ -78,6 +78,10 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
|||||||
bias1, bias2 = None, None
|
bias1, bias2 = None, None
|
||||||
_output_dtype = w2_scale.dtype
|
_output_dtype = w2_scale.dtype
|
||||||
|
|
||||||
|
weight_prefetch_method = get_forward_context().weight_prefetch_method
|
||||||
|
if weight_prefetch_method:
|
||||||
|
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
|
||||||
|
hidden_states)
|
||||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||||
if w1_scale_bias is None and is_mc2:
|
if w1_scale_bias is None and is_mc2:
|
||||||
if fusion and not dynamic_eplb:
|
if fusion and not dynamic_eplb:
|
||||||
|
|||||||
@@ -1,83 +1,112 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
|
||||||
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||||
AscendRowParallelLinear)
|
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
||||||
|
AscendRowParallelLinear)
|
||||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
|
||||||
|
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||||
|
MOE_PREFETCH_TOKEN_THRESHOLD = 96
|
||||||
@dataclass
|
|
||||||
class ModuleWeightPrefetchConfig:
|
|
||||||
module_name: str
|
@dataclass
|
||||||
enable: bool = False
|
class ModuleWeightPrefetchConfig:
|
||||||
prefetch_ratio: dict = field(default_factory=dict)
|
module_name: str
|
||||||
linear_prefix_map: dict = field(default_factory=dict)
|
enable: bool = False
|
||||||
|
is_active_this_forward: bool = False
|
||||||
def __post_init__(self) -> None:
|
prefetch_ratio: dict = field(default_factory=dict)
|
||||||
self.prefetch_ratio = {
|
linear_prefix_map: dict = field(default_factory=dict)
|
||||||
prefix: ratio
|
|
||||||
for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
|
def __post_init__(self) -> None:
|
||||||
}
|
self.prefetch_ratio = {
|
||||||
|
prefix: ratio
|
||||||
assert self.module_name in SUPPORTED_MODULES, (
|
for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
|
||||||
f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
|
}
|
||||||
)
|
|
||||||
|
assert self.module_name in SUPPORTED_MODULES, (
|
||||||
if self.module_name in SUPPORTED_MODULES:
|
f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
|
||||||
self.enable = self.enable and any(self.prefetch_ratio.values()) > 0
|
)
|
||||||
|
|
||||||
|
if self.module_name in SUPPORTED_MODULES:
|
||||||
class WeightPrefetchMethod:
|
self.enable = self.enable and any(self.prefetch_ratio.values()) > 0
|
||||||
"""
|
|
||||||
Unified weight prefetch method.
|
|
||||||
"""
|
class WeightPrefetchMethod:
|
||||||
|
"""
|
||||||
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
|
Unified weight prefetch method.
|
||||||
self.attn = ModuleWeightPrefetchConfig(
|
"""
|
||||||
module_name="attn",
|
|
||||||
enable=weight_prefetch_config.enabled,
|
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
|
||||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
self.attn = ModuleWeightPrefetchConfig(
|
||||||
"attn", {}),
|
module_name="attn",
|
||||||
linear_prefix_map={
|
enable=weight_prefetch_config.enabled,
|
||||||
AscendQKVParallelLinear.__name__: "qkv",
|
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||||
AscendRowParallelLinear.__name__: "o",
|
"attn", {}),
|
||||||
})
|
linear_prefix_map={
|
||||||
|
AscendQKVParallelLinear.__name__: "qkv",
|
||||||
def maybe_prefetch_attn_weight_preprocess(
|
AscendRowParallelLinear.__name__: "o",
|
||||||
self, layer_cls_name: str, weight: torch.Tensor,
|
})
|
||||||
start_flag: torch.Tensor) -> None:
|
self.moe = ModuleWeightPrefetchConfig(
|
||||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
module_name="moe",
|
||||||
return
|
enable=weight_prefetch_config.enabled,
|
||||||
|
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||||
prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
|
"moe", {}))
|
||||||
weight_size = weight.data.element_size() * weight.data.numel(
|
|
||||||
) * self.attn.prefetch_ratio.get(prefix, 0)
|
def maybe_prefetch_attn_weight_preprocess(
|
||||||
|
self, layer_cls_name: str, weight: torch.Tensor,
|
||||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
start_flag: torch.Tensor) -> None:
|
||||||
start_flag=start_flag,
|
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||||
max_weight_size=int(weight_size))
|
return
|
||||||
|
|
||||||
def maybe_prefetch_attn_weight_postprocess(
|
prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
|
||||||
self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
|
weight_size = weight.data.element_size() * weight.data.numel(
|
||||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
) * self.attn.prefetch_ratio.get(prefix, 0)
|
||||||
return
|
|
||||||
|
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
start_flag=start_flag,
|
||||||
|
max_weight_size=int(weight_size))
|
||||||
|
|
||||||
def maybe_npu_prefetch(inputs: torch.Tensor,
|
def maybe_prefetch_attn_weight_postprocess(
|
||||||
dependency: torch.Tensor,
|
self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
|
||||||
max_size: int = 0,
|
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||||
offset: int = 0,
|
return
|
||||||
*,
|
|
||||||
enabled: bool = True) -> None:
|
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||||
if not enabled:
|
|
||||||
return
|
def maybe_prefetch_moe_weight_preprocess(self, hidden_states, prefix):
|
||||||
input_size = inputs.element_size() * inputs.numel()
|
self.moe.is_active_this_forward = hidden_states.shape[
|
||||||
if max_size <= 0 or max_size > input_size:
|
0] >= MOE_PREFETCH_TOKEN_THRESHOLD if self.moe.enable else False
|
||||||
max_size = input_size
|
if not self.moe.is_active_this_forward:
|
||||||
torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
|
return
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
weight = forward_context.model_instance.model.layers[
|
||||||
|
forward_context.layer_idx].mlp.experts.w13_weight
|
||||||
|
weight_size = weight.data.element_size() * weight.data.numel(
|
||||||
|
) * self.moe.prefetch_ratio.get(prefix, 0)
|
||||||
|
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||||
|
start_flag=None,
|
||||||
|
max_weight_size=int(weight_size))
|
||||||
|
forward_context.layer_idx += 1
|
||||||
|
|
||||||
|
def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor):
|
||||||
|
if not self.moe.is_active_this_forward:
|
||||||
|
return
|
||||||
|
|
||||||
|
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_npu_prefetch(inputs: torch.Tensor,
|
||||||
|
dependency: torch.Tensor,
|
||||||
|
max_size: int = 0,
|
||||||
|
offset: int = 0,
|
||||||
|
*,
|
||||||
|
enabled: bool = True) -> None:
|
||||||
|
if not enabled:
|
||||||
|
return
|
||||||
|
input_size = inputs.element_size() * inputs.numel()
|
||||||
|
if max_size <= 0 or max_size > input_size:
|
||||||
|
max_size = input_size
|
||||||
|
torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
|
||||||
|
|||||||
Reference in New Issue
Block a user