[BugFix][Qwen3-VL]: fix cu_seqlens in qwen3-vl (#11458)
This commit is contained in:
@@ -452,13 +452,15 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
position_embeddings = (emb.cos(), emb.sin())
|
position_embeddings = (emb.cos(), emb.sin())
|
||||||
|
|
||||||
# compute cu_seqlens
|
# compute cu_seqlens
|
||||||
|
cu_seqlens = torch.repeat_interleave(
|
||||||
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
|
).cumsum(dim=0)
|
||||||
cu_seqlens = torch.cat(
|
cu_seqlens = torch.cat(
|
||||||
[
|
[
|
||||||
torch.tensor([0], device=grid_thw.device),
|
torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device),
|
||||||
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
|
cu_seqlens.to(torch.int32),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
|
||||||
|
|
||||||
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||||
x = x.unsqueeze(1)
|
x = x.unsqueeze(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user