[BugFix] Fixed the bug that caused the npu_transpose_batchmatmul operator to report an error due to the shape being too large (#3578)
### What this PR does / why we need it? `npu_transpose_batchmatmul` raises an error when its input shape is too large; this PR falls back to `torch.bmm` for oversized shapes. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: GDzhu1 <809721801@qq.com>
This commit is contained in:
@@ -559,13 +559,21 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
x = torch_npu.npu_transpose_batchmatmul(x,
|
||||
self.W_UV,
|
||||
perm_x1=[1, 0, 2],
|
||||
perm_x2=[0, 1, 2],
|
||||
perm_y=[1, 0, 2])
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
x = torch_npu.npu_transpose_batchmatmul(x,
|
||||
self.W_UV,
|
||||
perm_x1=[1, 0, 2],
|
||||
perm_x2=[0, 1, 2],
|
||||
perm_y=[1, 0, 2])
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
else:
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# # Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return x
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
|
||||
Reference in New Issue
Block a user