[Fix] use torch.cat instead of torch.concat to prevent entering the Autograd backends. (#4466)
This commit is contained in:
@@ -148,7 +148,7 @@ def lightning_attention_decode_naive(q, k, v, past_kv, slope):
             kv.to(torch.float32),
         )
         output.append(qkv)
-    output = torch.concat(output, dim=-2)
+    output = torch.cat(output, dim=-2)

     return output.to(original_dtype), kv
@@ -24,7 +24,7 @@ def naive_lightning_attention_decode(q, k, v, past_kv, slope):
             kv.to(torch.float32),
         )
         output.append(qkv)
-    output = torch.concat(output, dim=-2)
+    output = torch.cat(output, dim=-2)

     return output.to(original_dtype), kv
Reference in New Issue
Block a user