[Fix] Use torch.cat instead of torch.concat to avoid dispatching through the Autograd backends. (#4466)

This commit is contained in:
JieXin Liang
2025-03-16 15:02:47 +08:00
committed by GitHub
parent 81f431eded
commit 1a3fa75f2f
14 changed files with 20 additions and 20 deletions

View File

@@ -235,7 +235,7 @@ class MiniMaxText01LightningAttention(nn.Module):
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
-        output = torch.concat(output, dim=-2)
+        output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")

View File

@@ -403,7 +403,7 @@ class MiniMaxText01LightningAttention(nn.Module):
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
)
output.append(qkv)
-        output = torch.concat(output, dim=-2)
+        output = torch.cat(output, dim=-2)
# reshape
output = rearrange(output, "b h n d -> b n (h d)")
# normalize