[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)
This commit is contained in:
@@ -108,7 +108,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
PAGE_SIZE,
|
||||
)
|
||||
workspace_size = cutlass_mla_get_workspace_size(
|
||||
max_seqlen_pad * PAGE_SIZE, bs
|
||||
max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
|
||||
)
|
||||
workspace = torch.empty(
|
||||
workspace_size, device="cuda", dtype=torch.uint8
|
||||
@@ -138,7 +138,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
cuda_graph_kv_indices = block_kv_indices
|
||||
|
||||
workspace_size = cutlass_mla_get_workspace_size(
|
||||
cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
|
||||
cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
|
||||
)
|
||||
self.cuda_graph_mla_workspace = torch.empty(
|
||||
workspace_size, device="cuda", dtype=torch.uint8
|
||||
@@ -280,6 +280,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
seq_lens=forward_batch.seq_lens.to(torch.int32),
|
||||
page_table=self.forward_metadata.block_kv_indices,
|
||||
workspace=self.forward_metadata.workspace,
|
||||
num_kv_splits=1,
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
|
||||
@@ -95,6 +95,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
||||
seq_lens,
|
||||
block_table,
|
||||
workspace,
|
||||
1.44,
|
||||
num_kv_splits,
|
||||
),
|
||||
quantiles=quantiles,
|
||||
|
||||
@@ -36,7 +36,8 @@ void cutlass_mla_decode(
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace) {
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||
}
|
||||
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
@@ -98,6 +99,7 @@ typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
@@ -115,10 +117,7 @@ typename T::Fmha::Arguments args_from_options(
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
// the scale is based on the non-absorbed sizes, change as appropriate
|
||||
// we can't determine this parameter from the info we have, it's an input
|
||||
int D_non_latent = 128;
|
||||
float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));
|
||||
float scale = float(sm_scale);
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
@@ -183,11 +182,12 @@ void runMla(
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
at::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits,
|
||||
cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
@@ -215,6 +215,7 @@ void cutlass_mla_decode(
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
@@ -228,13 +229,13 @@ void cutlass_mla_decode(
|
||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
|
||||
m.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
|
||||
"page_table, Tensor! workspace, int num_kv_splits) -> ()");
|
||||
"page_table, Tensor! workspace, float sm_scale, int num_kv_splits) -> ()");
|
||||
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
|
||||
|
||||
|
||||
8
sgl-kernel/include/sgl_kernel_ops.h
Executable file → Normal file
8
sgl-kernel/include/sgl_kernel_ops.h
Executable file → Normal file
@@ -111,9 +111,13 @@ void cutlass_mla_decode(
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits = -1);
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
int64_t cutlass_mla_get_workspace_size(
|
||||
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
|
||||
int64_t max_seq_len,
|
||||
int64_t num_batches,
|
||||
int64_t sm_count = 0,
|
||||
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
|
||||
/*
|
||||
* From csrc/elementwise
|
||||
*/
|
||||
|
||||
@@ -58,7 +58,8 @@ def cutlass_mla_decode(
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
num_kv_splits: int = -1,
|
||||
sm_scale: float,
|
||||
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
|
||||
) -> torch.Tensor:
|
||||
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||
@@ -118,13 +119,17 @@ def cutlass_mla_decode(
|
||||
seq_lens,
|
||||
page_table,
|
||||
workspace,
|
||||
sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
return out[:, :H].contiguous()
|
||||
|
||||
|
||||
def cutlass_mla_get_workspace_size(
|
||||
max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
|
||||
max_seq_len: int,
|
||||
num_batches: int,
|
||||
sm_count: int = 0,
|
||||
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
|
||||
) -> int:
|
||||
assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
|
||||
assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
|
||||
|
||||
@@ -93,7 +93,7 @@ def test_cutlass_mla_decode(
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
out = cutlass_mla_decode(
|
||||
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
|
||||
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, scale, num_kv_splits
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
||||
|
||||
Reference in New Issue
Block a user