From b2c856692092ddc8999520cafaf610cf9db8c8cd Mon Sep 17 00:00:00 2001 From: Zheng Wengang Date: Wed, 15 Oct 2025 13:16:49 +0800 Subject: [PATCH] [BugFix][Qwen3-VL]: fix cu_seqlens in qwen3-vl (#11458) --- python/sglang/srt/models/qwen3_vl.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 0f8995307..0db6541d3 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -452,13 +452,15 @@ class Qwen3_VisionTransformer(nn.Module): position_embeddings = (emb.cos(), emb.sin()) # compute cu_seqlens + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0) cu_seqlens = torch.cat( [ - torch.tensor([0], device=grid_thw.device), - (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0), + torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), + cu_seqlens.to(torch.int32), ] ) - cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) # max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) x = x.unsqueeze(1)