adapt the mla_v1 with the mla_preprocess kernel (#3397)
### What this PR does / why we need it?

This pull request integrates a new `mla_preprocess` kernel to create an optimized path for MLA (Multi-head Latent Attention) decode operations on Ascend hardware, controlled by an environment flag. The changes include new utility functions for weight transformation, a method that prepares weights for the fused kernel, and logic that routes decode-only batches to the new path. Review identified a critical bug in the `transdata` utility, where the row and column padding amounts passed to `F.pad` were swapped, which would produce incorrect tensor shapes and kernel failures, as well as a maintainability issue in `trans_rope_weight`, which modified its input in place; a pure-function alternative was adopted, and both fixes are reflected in the code below.

### Does this PR introduce _any_ user-facing change?

No user-facing changes by default. Users can enable the `mla_preprocess` kernel by setting the environment variable `VLLM_ASCEND_ENABLE_MLAPO`.

### How was this patch tested?

Dedicated Ascend kernels are not covered by our CI yet, so no extra automated tests were added; future MLA-focused regression runs will cover this path.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
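For illustration, here is a minimal sketch of how such a gate is typically read and applied, assuming the flag is an ordinary boolean environment variable; the helper names below are hypothetical, not the PR's actual code:

```python
import os


def mlapo_enabled() -> bool:
    # Hypothetical reader for the flag named in this PR; vllm-ascend's
    # real env handling may parse it differently.
    return os.getenv("VLLM_ASCEND_ENABLE_MLAPO", "0").lower() in ("1", "true")


def should_use_mla_preprocess(is_decode_only_batch: bool) -> bool:
    # Per the description above, only decode-only batches are routed
    # to the fused mla_preprocess path.
    return mlapo_enabled() and is_decode_only_batch
```

The first diff hunk updates the imports (the `torch.nn.functional` import is the new line, used by the new `transdata` helper below):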
```diff
@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from typing import Any, List
 
 import torch
+import torch.nn.functional as F
 import torch_npu
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
```
```diff
@@ -153,3 +154,39 @@ def version_check():
     if full_date >= "20250919":
         return True
     return False
+
+
+def round_up(val: int, align: int) -> int:
+    if align == 0:
+        return 0
+    return -(val // -align) * align
```
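The `-(val // -align) * align` expression in `round_up` is ceiling division written with floor division: Python's `//` rounds toward negative infinity, so negating the numerator and then the result rounds up instead. A quick check:

```python
# Smallest multiple of 16 that is >= 30:
assert -(30 // -16) == 2            # ceil(30 / 16)
assert -(30 // -16) * 16 == 32      # 30 rounds up to 32
assert -(32 // -16) * 16 == 32      # already-aligned values are unchanged
```

The hunk continues with `trans_rope_weight`, which splits the second-to-last dimension of a weight into its "nope" rows and RoPE rows, and regroups the RoPE rows from interleaved (even/odd alternating) order into evens-then-odds: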
```diff
+
+
+def trans_rope_weight(weight, rope_dim):
+    if rope_dim == 0:
+        return weight.contiguous()
+    nope_part = weight[..., :-rope_dim, :]
+    rope_part = weight[..., -rope_dim:, :]
+    reordered_rope_part = torch.cat(
+        (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
+    return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
```
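A small sanity check of the reordering, assuming the `trans_rope_weight` above is in scope (an 8x1 weight whose last four rows form the RoPE part):

```python
import torch

weight = torch.arange(8, dtype=torch.float32).unsqueeze(-1)  # shape (8, 1)
out = trans_rope_weight(weight, rope_dim=4)
# Rows 0-3 (the "nope" part) are untouched; the RoPE rows [4, 5, 6, 7]
# are regrouped into evens-then-odds: [4, 6, 5, 7].
assert out.squeeze(-1).tolist() == [0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 5.0, 7.0]
```

The hunk closes with `transdata`, which pads a 2-D matrix up to whole `block_size` tiles and rearranges it into the NZ (fractal) layout used on Ascend. Note the `F.pad` call: PyTorch pads the last dimension first, so the column padding must come before the row padding; the version below uses the corrected argument order flagged in the review: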
```diff
+
+
+def transdata(nd_mat, block_size: tuple = (16, 16)):
+    r = round_up(nd_mat.shape[0], block_size[0])
+    c = round_up(nd_mat.shape[1], block_size[1])
+    r_pad = r - nd_mat.shape[0]
+    c_pad = c - nd_mat.shape[1]
+    nd_mat = F.pad(nd_mat, (0, c_pad, 0, r_pad))
+    nz_mat = torch.permute(
+        torch.reshape(
+            nd_mat,
+            (r // block_size[0], block_size[0], c // block_size[1],
+             block_size[1]),
+        ),
+        [2, 0, 1, 3],
+    )
+    nz_mat = torch.reshape(
+        nz_mat,
+        (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
+    return nz_mat
```
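With the corrected padding, a `(20, 40)` input is padded to `(32, 48)` and tiled into `48 // 16 = 3` column blocks, each flattened to `(32, 16)`. A shape check, assuming the `transdata` and `round_up` above are in scope:

```python
import torch

mat = torch.randn(20, 40)
nz = transdata(mat, block_size=(16, 16))
assert nz.shape == (3, 32, 16)   # (c // 16, padded rows, 16)
# Block column 0 holds the first 16 input columns, rows stacked in order:
assert torch.equal(nz[0, :20, :], mat[:, :16])
```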