[main] v_proj combining transpose and matmul (#3545)
### What this PR does / why we need it? v_proj combining transpose and matmul ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed with new added/existing test. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: GDzhu1 <809721801@qq.com>
This commit is contained in:
@@ -559,12 +559,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
x = torch_npu.npu_transpose_batchmatmul(x,
|
||||
self.W_UV,
|
||||
perm_x1=[1, 0, 2],
|
||||
perm_x2=[0, 1, 2],
|
||||
perm_y=[1, 0, 2])
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return x
|
||||
|
||||
# Return `ql_nope`, `q_pe`
|
||||
|
||||
Reference in New Issue
Block a user