[main] v_proj combining transpose and matmul (#3545)
### What this PR does / why we need it?
Optimizes `_v_up_proj` by replacing the separate transpose and `torch.bmm` calls with the fused `torch_npu.npu_transpose_batchmatmul` operator, combining the transpose and matmul into a single NPU kernel.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing tests.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: GDzhu1 <809721801@qq.com>
This commit is contained in:
@@ -559,12 +559,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.speculative_config = vllm_config.speculative_config
|
self.speculative_config = vllm_config.speculative_config
|
||||||
|
|
||||||
def _v_up_proj(self, x):
|
def _v_up_proj(self, x):
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
x = x.view(-1, self.num_heads, self.kv_lora_rank)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
x = torch_npu.npu_transpose_batchmatmul(x,
|
||||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
self.W_UV,
|
||||||
x = torch.bmm(x, self.W_UV)
|
perm_x1=[1, 0, 2],
|
||||||
# Convert from (N, B, V) to (B, N * V)
|
perm_x2=[0, 1, 2],
|
||||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
perm_y=[1, 0, 2])
|
||||||
|
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
# Return `ql_nope`, `q_pe`
|
# Return `ql_nope`, `q_pe`
|
||||||
|
|||||||
Reference in New Issue
Block a user