[OPS] add bmm_transpose ops (#3990)

### What this PR does / why we need it?
Add a new fused op, `bmm_transpose`, to custom_op, which fuses `torch.bmm()` and the subsequent transpose to achieve better performance. The op is used in mla_v1 to replace the separate bmm and transpose calls.
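
For clarity, the fused op computes the same result as the following unfused reference (a minimal sketch; `bmm_transpose_reference` is a hypothetical name, and the shapes follow the `_v_up_proj` call site in the diff below):

```python
import torch

def bmm_transpose_reference(x: torch.Tensor, w_uv: torch.Tensor) -> torch.Tensor:
    # x: (B, N, L) = (tokens, num_heads, kv_lora_rank)
    # w_uv: (N, L, V) with V = v_head_dim
    # Transpose to (N, B, L), batch-matmul to (N, B, V), then transpose back
    # to (B, N, V). The fused kernel avoids materializing the intermediate
    # transposed tensors.
    return torch.bmm(x.transpose(0, 1), w_uv).transpose(0, 1)
```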

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?


- vLLM version: v0.11.2

---------

Signed-off-by: hust17yixuan <303660421@qq.com>

```diff
@@ -887,15 +887,16 @@ class AscendMLAImpl(MLAAttentionImpl):
         ).device_group if self.tp_size > 1 else None
 
     def _v_up_proj(self, x):
-        if self.W_UV.shape[0] * self.W_UV.shape[
-                1] < 65536 and not self.dcp_size * self.pcp_size > 1:
+        if x.dtype in [torch.float16, torch.bfloat16] \
+                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
+                and not self.dcp_size * self.pcp_size > 1:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
-            x = torch_npu.npu_transpose_batchmatmul(x,
-                                                    self.W_UV,
-                                                    perm_x1=[1, 0, 2],
-                                                    perm_x2=[0, 1, 2],
-                                                    perm_y=[1, 0, 2])
-            x = x.reshape(-1, self.num_heads * self.v_head_dim)
+            b, _, _ = x.shape
+            res = torch.empty((b, self.num_heads, self.v_head_dim),
+                              dtype=x.dtype,
+                              device=x.device)
+            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
+            x = res.reshape(-1, self.num_heads * self.v_head_dim)
         else:
             # Convert from (B, N, L) to (N, B, L)
             x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
```
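
A minimal parity check of the new fused path against the unfused fallback could look like the sketch below (assumptions: an Ascend build of vllm-ascend that registers `torch.ops._C_ascend.batch_matmul_transpose`, `torch_npu` installed, and example shapes standing in for the real model dimensions):

```python
import torch
import torch_npu  # noqa: F401  (registers the "npu" device)

B, N, L, V = 32, 16, 512, 128  # tokens, num_heads, kv_lora_rank, v_head_dim
x = torch.randn(B, N, L, dtype=torch.bfloat16, device="npu")
w_uv = torch.randn(N, L, V, dtype=torch.bfloat16, device="npu")

# Fused path: the output buffer is preallocated, as in _v_up_proj above.
res = torch.empty((B, N, V), dtype=x.dtype, device=x.device)
torch.ops._C_ascend.batch_matmul_transpose(x, w_uv, res)

# Unfused reference: (B, N, L) -> (N, B, L), bmm with (N, L, V), back to (B, N, V).
ref = torch.bmm(x.transpose(0, 1), w_uv).transpose(0, 1)
torch.testing.assert_close(res, ref, rtol=4e-2, atol=4e-2)  # loose bf16 tolerance
```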