[Fix] Use torch.cat instead of torch.concat to prevent entering the Autograd backends. (#4466)
This commit is contained in:
@@ -235,7 +235,7 @@ class MiniMaxText01LightningAttention(nn.Module):
|
||||
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
|
||||
)
|
||||
output.append(qkv)
|
||||
output = torch.concat(output, dim=-2)
|
||||
output = torch.cat(output, dim=-2)
|
||||
|
||||
# reshape
|
||||
output = rearrange(output, "b h n d -> b n (h d)")
|
||||
|
||||
@@ -403,7 +403,7 @@ class MiniMaxText01LightningAttention(nn.Module):
|
||||
"... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype)
|
||||
)
|
||||
output.append(qkv)
|
||||
output = torch.concat(output, dim=-2)
|
||||
output = torch.cat(output, dim=-2)
|
||||
# reshape
|
||||
output = rearrange(output, "b h n d -> b n (h d)")
|
||||
# normalize
|
||||
|
||||
Reference in New Issue
Block a user