refactor: bug fixes and refactor for vlm (#4661)

This commit is contained in:
Mick
2025-03-23 13:48:49 +08:00
committed by GitHub
parent ca75741e86
commit 11577cedb7
31 changed files with 770 additions and 735 deletions

View File

@@ -143,9 +143,14 @@ class VisionAttention(nn.Module):
if position_embeddings is not None:
cos, sin = position_embeddings
original_shape = q.shape
q, k = q.view(s, head, -1), k.view(s, head, -1)
# [total_tokens, head, head_size]
q = q.view(-1, head, self.head_size)
k = k.view(-1, head, self.head_size)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q, k = q.reshape(original_shape), k.reshape(original_shape)
q = q.view(original_shape)
k = k.view(original_shape)
if self.use_qkv_parallel:
pass