Replace pad with cat for better performance (#11388)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -323,7 +323,7 @@ class DotsVisionTransformer(PreTrainedModel):
|
|||||||
dim=0,
|
dim=0,
|
||||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||||
)
|
)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||||
|
|
||||||
for blk in self.blocks:
|
for blk in self.blocks:
|
||||||
hidden_states = blk(
|
hidden_states = blk(
|
||||||
|
|||||||
@@ -434,7 +434,7 @@ class Glm4vVisionModel(nn.Module):
|
|||||||
cu_seqlens = torch.repeat_interleave(
|
cu_seqlens = torch.repeat_interleave(
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
).cumsum(dim=0, dtype=torch.int32)
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||||
|
|
||||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
x = self.embeddings(
|
x = self.embeddings(
|
||||||
|
|||||||
@@ -436,7 +436,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
.to(device=x.device, dtype=torch.int32),
|
.to(device=x.device, dtype=torch.int32),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||||
|
|
||||||
# transformers
|
# transformers
|
||||||
x = x.unsqueeze(1)
|
x = x.unsqueeze(1)
|
||||||
|
|||||||
@@ -407,7 +407,7 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
cu_seqlens = torch.repeat_interleave(
|
cu_seqlens = torch.repeat_interleave(
|
||||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||||
).cumsum(dim=0, dtype=torch.int32)
|
).cumsum(dim=0, dtype=torch.int32)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||||
|
|
||||||
# transformers
|
# transformers
|
||||||
x = x.unsqueeze(1)
|
x = x.unsqueeze(1)
|
||||||
|
|||||||
@@ -458,7 +458,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
|||||||
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
|
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
|
||||||
|
|
||||||
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
# max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
|
||||||
x = x.unsqueeze(1)
|
x = x.unsqueeze(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user