Tiny fix cutlass_mla_get_workspace_size stub incorrect signature (#7057)
This commit is contained in:
@@ -37,7 +37,7 @@ void cutlass_mla_decode(
|
|||||||
torch::Tensor const& workspace) {
|
torch::Tensor const& workspace) {
|
||||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||||
}
|
}
|
||||||
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
|
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
|||||||
Reference in New Issue
Block a user