[Bugfix] bugfix_for_bmm_transpose (#4899)
Due to shape limitations, the bmm_transpose operator in version 3.2 can only be used in the decoding stage; this change adds a with_prefill guard so the fused kernel is skipped whenever the batch contains prefill tokens.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: ChrisGelhLan <33011886+xlan-huawei@users.noreply.github.com>
@@ -490,9 +490,11 @@ class AscendSFAImpl(MLAAttentionImpl):
         self._process_weights_for_fused_mlapo(act_dtype)
 
     def _v_up_proj(self, x):
+        forward_context = get_forward_context()
         if x.dtype in [torch.float16, torch.bfloat16] \
                 and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") \
-                and not self.enable_sfa_cp:
+                and not self.enable_sfa_cp \
+                and not forward_context.with_prefill:
             x = x.view(-1, self.num_heads, self.kv_lora_rank)
             b, _, _ = x.shape
             res = torch.empty((b, self.num_heads, self.v_head_dim),
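For context, the guard above produces a decode-only fast path. Below is a minimal, self-contained sketch of that gating pattern, not the repository's actual code: the weight name w_uv, its (num_heads, kv_lora_rank, v_head_dim) layout, and the out-parameter call signature of the custom op are assumptions for illustration. On a machine without the Ascend extension the hasattr check returns False and the plain torch.bmm fallback runs.

import torch

def v_up_proj_sketch(x, w_uv, num_heads, kv_lora_rank, v_head_dim, with_prefill):
    # x: (tokens, num_heads * kv_lora_rank); w_uv: assumed per-head
    # up-projection weight of shape (num_heads, kv_lora_rank, v_head_dim).
    x = x.view(-1, num_heads, kv_lora_rank)
    fast_path = (
        x.dtype in (torch.float16, torch.bfloat16)
        and hasattr(torch.ops._C_ascend, "batch_matmul_transpose")
        and not with_prefill  # the new guard: the fused kernel is decode-only
    )
    if fast_path:
        b = x.shape[0]
        res = torch.empty((b, num_heads, v_head_dim),
                          dtype=x.dtype, device=x.device)
        # Hypothetical call signature; the real op lives in vllm-ascend's
        # _C_ascend extension and writes its result into `res`.
        torch.ops._C_ascend.batch_matmul_transpose(x, w_uv, res)
        out = res
    else:
        # Shape-agnostic fallback: plain batched matmul per head, used for
        # prefill batches and for dtypes the fused kernel does not support.
        out = torch.bmm(x.transpose(0, 1), w_uv).transpose(0, 1)
    return out.reshape(-1, num_heads * v_head_dim)

For example, with num_heads=8, kv_lora_rank=512, and v_head_dim=128, calling v_up_proj_sketch(torch.randn(4, 8 * 512), torch.randn(8, 512, 128), 8, 512, 128, with_prefill=True) takes the fallback branch and returns a (4, 1024) tensor.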