[2/N][Feat] Attention and MoE weight prefetch in Qwen3MoE models (#3203)
### What this PR does / why we need it?
- Refacotr and integrate a unified `WeightPrefetchMethod`
- Integrate `gate_up_proj.weight` in quantized Attention modules
- Prefetching these weights ahead of matmul-like operators imporves
performance by reducing L2 cache transfer latency
### Does this PR introduce _any_ user-facing change?
Add a new config in `--additional-config` for configuration:
```json
{
"weight_prefetch_config": {
"enabled": True,
"prefetch_ratio": {
"moe": {
"gate_up": 0.8
},
},
},
}
```
This feature is enabled by default, and can be disabled through this
configuration
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
---------
Signed-off-by: yuzhup <15705211260@163.com>
This commit is contained in:
@@ -291,7 +291,9 @@ def test_select_experts(
|
||||
custom_routing_function.return_value = (mock_weights, mock_ids)
|
||||
|
||||
with patch("vllm_ascend.ops.moe.experts_selector._native_grouped_topk"
|
||||
) as mock_native_grouped_topk:
|
||||
) as mock_native_grouped_topk, \
|
||||
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
return_value=MagicMock(weight_prefetch_method=MagicMock())):
|
||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||
x)
|
||||
|
||||
@@ -325,7 +327,9 @@ def test_select_experts(
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICE)
|
||||
def test_select_experts_invalid_scoring_func(device: str):
|
||||
with pytest.raises(ValueError,
|
||||
with patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
return_value=MagicMock(weight_prefetch_method=MagicMock())), \
|
||||
pytest.raises(ValueError,
|
||||
match="Unsupported scoring function: invalid"):
|
||||
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
||||
router_logits=torch.randn(1, 8, device=device),
|
||||
|
||||
@@ -92,14 +92,16 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
|
||||
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
||||
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
||||
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
|
||||
moe_comm_type=MoECommType.MC2,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=dp_metadata,
|
||||
mc2_mask=torch.zeros(
|
||||
16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False)
|
||||
mock_weight_prefetch_method = MagicMock()
|
||||
mock_forward_context_obj = MagicMock(
|
||||
moe_comm_method=mock_moe_comm_method,
|
||||
moe_comm_type=MoECommType.MC2,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=dp_metadata,
|
||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False,
|
||||
weight_prefetch_method=mock_weight_prefetch_method)
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
@@ -132,7 +134,9 @@ def mock_dist_env(mocker: MockerFixture):
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
return_value=None):
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
return_value=mock_forward_context_obj):
|
||||
|
||||
yield {
|
||||
'mock_forward_context_obj': mock_forward_context_obj,
|
||||
|
||||
@@ -755,6 +755,14 @@ class TestSelectExperts(TestBase):
|
||||
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
|
||||
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
|
||||
|
||||
self.mock_ctx = MagicMock()
|
||||
self.mock_ctx.weight_prefetch_method = MagicMock()
|
||||
patcher = patch(
|
||||
'vllm_ascend.ops.moe.experts_selector.get_forward_context',
|
||||
return_value=self.mock_ctx)
|
||||
self.addCleanup(patcher.stop)
|
||||
patcher.start()
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k_softmax')
|
||||
def test_softmax_scoring(self, mock_topk):
|
||||
"""Test softmax scoring function"""
|
||||
|
||||
Reference in New Issue
Block a user