[Op] DeepSeekV3.2 support bmm_transpose operator (#4631)
### What this PR does / why we need it?
DeepSeekV3.2 support bmm_transpose operator.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -484,14 +484,15 @@ class AscendSFAImpl(MLAAttentionImpl):
|
||||
self._process_weights_for_fused_mlapo(act_dtype)
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
|
||||
x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
|
||||
x = torch_npu.npu_transpose_batchmatmul(x,
|
||||
self.W_UV,
|
||||
perm_x1=[1, 0, 2],
|
||||
perm_x2=[0, 1, 2],
|
||||
perm_y=[1, 0, 2])
|
||||
x = x.reshape(-1, self.local_num_heads * self.v_head_dim)
|
||||
if x.dtype in [torch.float16, torch.bfloat16] \
|
||||
and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
b, _, _ = x.shape
|
||||
res = torch.empty((b, self.num_heads, self.v_head_dim),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
|
||||
x = res.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
else:
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.local_num_heads,
|
||||
|
||||
Reference in New Issue
Block a user