Replace pad with cat for better performance (#11388)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
Yuan Luo
2025-10-10 12:03:17 +08:00
committed by GitHub
parent 70fbb3adf6
commit b5044fbf12
5 changed files with 5 additions and 5 deletions

View File

@@ -323,7 +323,7 @@ class DotsVisionTransformer(PreTrainedModel):
dim=0,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
-cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
for blk in self.blocks:
hidden_states = blk(