[Fix] use torch.cat instead of torch.concat to prevent entering the Autograd backends. (#4466)

This commit is contained in:
JieXin Liang
2025-03-16 15:02:47 +08:00
committed by GitHub
parent 81f431eded
commit 1a3fa75f2f
14 changed files with 20 additions and 20 deletions

View File

@@ -148,7 +148,7 @@ def lightning_attention_decode_naive(q, k, v, past_kv, slope):
kv.to(torch.float32),
)
output.append(qkv)
output = torch.concat(output, dim=-2)
output = torch.cat(output, dim=-2)
return output.to(original_dtype), kv

View File

@@ -24,7 +24,7 @@ def naive_lightning_attention_decode(q, k, v, past_kv, slope):
kv.to(torch.float32),
)
output.append(qkv)
output = torch.concat(output, dim=-2)
output = torch.cat(output, dim=-2)
return output.to(original_dtype), kv