[Op] DeepSeekV3.2: support bmm_transpose operator (#4631)
### What this PR does / why we need it?
Add bmm_transpose operator support for DeepSeekV3.2. In `AscendSFAImpl._v_up_proj`, when the activations are float16/bfloat16 and the custom `torch.ops._C_ascend.batch_matmul_transpose` kernel is registered, the V up-projection dispatches to it; otherwise it falls back to the existing path.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: ZYang6263 <zy626375@gmail.com>
Signed-off-by: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -484,14 +484,15 @@ class AscendSFAImpl(MLAAttentionImpl):
         self._process_weights_for_fused_mlapo(act_dtype)

     def _v_up_proj(self, x):
-        if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
-            x = x.view(-1, self.local_num_heads, self.kv_lora_rank)
-            x = torch_npu.npu_transpose_batchmatmul(x,
-                                                    self.W_UV,
-                                                    perm_x1=[1, 0, 2],
-                                                    perm_x2=[0, 1, 2],
-                                                    perm_y=[1, 0, 2])
-            x = x.reshape(-1, self.local_num_heads * self.v_head_dim)
+        if x.dtype in [torch.float16, torch.bfloat16] \
+                and hasattr(torch.ops._C_ascend, "batch_matmul_transpose"):
+            x = x.view(-1, self.num_heads, self.kv_lora_rank)
+            b, _, _ = x.shape
+            res = torch.empty((b, self.num_heads, self.v_head_dim),
+                              dtype=x.dtype,
+                              device=x.device)
+            torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res)
+            x = res.reshape(-1, self.num_heads * self.v_head_dim)
         else:
             # Convert from (B, N, L) to (N, B, L)
             x = x.view(-1, self.local_num_heads,
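For reference, below is a minimal eager-mode sketch of what the fused kernel is expected to compute, inferred from the `perm_x1`/`perm_x2`/`perm_y` arguments of the `npu_transpose_batchmatmul` call it replaces. The function name `batch_matmul_transpose_ref` and the smoke-test shapes are hypothetical and not part of this PR; on the NPU the real kernel fuses the two permutes into the matmul rather than materializing them.

```python
import torch


def batch_matmul_transpose_ref(x: torch.Tensor, w_uv: torch.Tensor,
                               res: torch.Tensor) -> None:
    # x:    (B, N, L) latent vectors per token (B) and head (N)
    # w_uv: (N, L, V) per-head V up-projection weights (self.W_UV)
    # res:  (B, N, V) preallocated output buffer, written in place
    # Mirrors the replaced call: permute x (B,N,L)->(N,B,L) [perm_x1],
    # bmm with w_uv (N,L,V) [perm_x2], then permute the result back
    # to (B,N,V) [perm_y].
    y = torch.bmm(x.transpose(0, 1), w_uv)  # (N, B, V)
    res.copy_(y.transpose(0, 1))            # (B, N, V)


if __name__ == "__main__":
    # Smoke test in float32 (the NPU kernel itself targets fp16/bf16).
    B, N, L, V = 4, 8, 64, 32  # arbitrary small shapes
    x = torch.randn(B, N, L)
    w = torch.randn(N, L, V)
    out = torch.empty(B, N, V)
    batch_matmul_transpose_ref(x, w, out)
    ref = torch.einsum("bnl,nlv->bnv", x, w)
    assert torch.allclose(out, ref, atol=1e-4)
```

Writing into a preallocated `res` buffer avoids the explicit permute/copy steps of the old path, which is presumably why the new kernel takes an output argument instead of returning a tensor.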