[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)
This commit is contained in:
8
sgl-kernel/include/sgl_kernel_ops.h
Executable file → Normal file
8
sgl-kernel/include/sgl_kernel_ops.h
Executable file → Normal file
@@ -111,9 +111,13 @@ void cutlass_mla_decode(
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits = -1);
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
int64_t cutlass_mla_get_workspace_size(
|
||||
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
|
||||
int64_t max_seq_len,
|
||||
int64_t num_batches,
|
||||
int64_t sm_count = 0,
|
||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user