bugfix for qwen2_5_vl (#805)
### What this PR does / why we need it? the interface of qwen2.5vl changes from column linear to qkv linear, this makes our weight pad func become abnormal, thus we optimize split_qkv func to fix this bug. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? with CI Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -70,6 +70,19 @@ class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):
|
||||
if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
|
||||
self.hidden_size_per_attention_head = MAX_PAD_SIZE
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
|
||||
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head)
|
||||
q, k, v = (x.view(*new_shape) for x in (q, k, v))
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user