[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)

### What this PR does / why we need it?

- Refacotr and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` in quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators imporves
performance by reducing L2 cache transfer latency

### Does this PR introduce _any_ user-facing change?

Add a new config in `--additional-config` for configuration:
```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0,
            },
        },
    },
}
```
This feature is enabled by default, and can be disabled through this
configuration

### How was this patch tested?


- vLLM version: v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
This commit is contained in:
Ruri
2025-10-09 20:38:39 +08:00
committed by GitHub
parent 23db56a340
commit ff37575936
13 changed files with 264 additions and 69 deletions

View File

@@ -68,16 +68,23 @@ class TestAscendW8A8LinearMethod(TestBase):
self.assertEqual(params['weight_scale'].shape, (10, 1))
self.assertEqual(params['weight_offset'].shape, (10, 1))
@patch("vllm_ascend.quantization.w8a8.get_forward_context")
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
mock_quant_per_tensor):
mock_quant_per_tensor,
mock_get_forward_context):
layer = MagicMock()
layer.aclnn_input_scale = 0.1
layer.aclnn_input_offset = 0.2
layer.weight = torch.randn(128, 256)
layer.deq_scale = 0.3
mock_forward_context = MagicMock()
mock_get_forward_context.return_value = mock_forward_context
mock_weight_prefetch_method = MagicMock()
mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
x = torch.randn(32, 128)
bias = torch.randn(256)
mock_quant_per_tensor.return_value = torch.randint(-128,