[Op] DeepSeekV3.2 support bmm_transpose operator (#4631)

### What this PR does / why we need it?
DeepSeekV3.2 support bmm_transpose operator.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
ZYang6263
2025-12-08 14:03:38 +08:00
committed by GitHub
parent 0b65ac6c4b
commit a433f3280a

View File

@@ -484,14 +484,15 @@ class AscendSFAImpl(MLAAttentionImpl):
self._process_weights_for_fused_mlapo(act_dtype) self._process_weights_for_fused_mlapo(act_dtype)
def _v_up_proj(self, x): def _v_up_proj(self, x):
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536: if x.dtype in [torch.float16, torch.bfloat16] \
x = x.view(-1, self.local_num_heads, self.kv_lora_rank) and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
x = torch_npu.npu_transpose_batchmatmul(x, x = x.view(-1, self.num_heads, self.kv_lora_rank)
self.W_UV, b, _, _ = x.shape
perm_x1=[1, 0, 2], res = torch.empty((b, self.num_heads, self.v_head_dim),
perm_x2=[0, 1, 2], dtype=x.dtype,
perm_y=[1, 0, 2]) device=x.device)
x = x.reshape(-1, self.local_num_heads * self.v_head_dim) torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
x = res.reshape(-1, self.num_heads * self.v_head_dim)
else: else:
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.local_num_heads, x = x.view(-1, self.local_num_heads,