refactor: bug fixes and refactor for vlm (#4661)
This commit is contained in:
@@ -143,9 +143,14 @@ class VisionAttention(nn.Module):
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
original_shape = q.shape
|
||||
q, k = q.view(s, head, -1), k.view(s, head, -1)
|
||||
# [total_tokens, head, head_size]
|
||||
q = q.view(-1, head, self.head_size)
|
||||
k = k.view(-1, head, self.head_size)
|
||||
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
q, k = q.reshape(original_shape), k.reshape(original_shape)
|
||||
|
||||
q = q.view(original_shape)
|
||||
k = k.view(original_shape)
|
||||
|
||||
if self.use_qkv_parallel:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user