Remove 200us slow concat kernel (part 1: kernel) (#7145)
This commit is contained in:
@@ -38,6 +38,7 @@ configs = list(itertools.product(bs_range, qlen_range))
|
|||||||
)
|
)
|
||||||
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
||||||
d = 576
|
d = 576
|
||||||
|
dn = 64
|
||||||
dv = 512
|
dv = 512
|
||||||
|
|
||||||
h_q_map = {
|
h_q_map = {
|
||||||
@@ -63,7 +64,11 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
|||||||
pack_factor = 128 // block_size
|
pack_factor = 128 // block_size
|
||||||
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
|
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
|
||||||
|
|
||||||
q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
|
qn = (
|
||||||
|
torch.randn(h_q, batch_size, d - dn, dtype=torch.bfloat16, device="cuda")
|
||||||
|
* 100.0
|
||||||
|
)
|
||||||
|
qr = torch.randn(batch_size, h_q, dn, dtype=torch.bfloat16, device="cuda") * 100.0
|
||||||
block_table = torch.randint(
|
block_table = torch.randint(
|
||||||
0,
|
0,
|
||||||
batch_size * block_num,
|
batch_size * block_num,
|
||||||
@@ -84,16 +89,22 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
|
|||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
lambda: cutlass_mla_decode(
|
lambda: cutlass_mla_decode(
|
||||||
q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
|
qn.transpose(0, 1),
|
||||||
|
qr,
|
||||||
|
kv_cache,
|
||||||
|
seq_lens,
|
||||||
|
block_table,
|
||||||
|
workspace,
|
||||||
|
num_kv_splits,
|
||||||
),
|
),
|
||||||
quantiles=quantiles,
|
quantiles=quantiles,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
q_size = qn.numel() * qn.element_size() + qr.numel() * qr.element_size()
|
||||||
|
|
||||||
gbps = (
|
gbps = (
|
||||||
lambda ms: (
|
lambda ms: (
|
||||||
q.numel() * q.element_size()
|
q_size + q_size * dv / d + kv_cache.numel() * kv_cache.element_size()
|
||||||
+ q.numel() * q.element_size() * dv / d
|
|
||||||
+ kv_cache.numel() * kv_cache.element_size()
|
|
||||||
)
|
)
|
||||||
* 1e-9
|
* 1e-9
|
||||||
/ (ms * 1e-3)
|
/ (ms * 1e-3)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
|
|
||||||
#include <cute/tensor.hpp>
|
#include <cute/tensor.hpp>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
||||||
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
||||||
@@ -30,7 +31,8 @@ limitations under the License.
|
|||||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||||
void cutlass_mla_decode(
|
void cutlass_mla_decode(
|
||||||
torch::Tensor const& out,
|
torch::Tensor const& out,
|
||||||
torch::Tensor const& q_nope_and_q_pe,
|
torch::Tensor const& q_nope,
|
||||||
|
torch::Tensor const& q_pe,
|
||||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||||
torch::Tensor const& seq_lens,
|
torch::Tensor const& seq_lens,
|
||||||
torch::Tensor const& page_table,
|
torch::Tensor const& page_table,
|
||||||
@@ -91,16 +93,17 @@ struct MlaSm100 {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
typename T::Fmha::Arguments args_from_options(
|
typename T::Fmha::Arguments args_from_options(
|
||||||
at::Tensor const& out,
|
at::Tensor const& out,
|
||||||
at::Tensor const& q_nope_and_q_pe,
|
at::Tensor const& q_nope,
|
||||||
|
at::Tensor const& q_pe,
|
||||||
at::Tensor const& kv_c_and_k_pe_cache,
|
at::Tensor const& kv_c_and_k_pe_cache,
|
||||||
at::Tensor const& seq_lens,
|
at::Tensor const& seq_lens,
|
||||||
at::Tensor const& page_table,
|
at::Tensor const& page_table,
|
||||||
int64_t num_kv_splits) {
|
int64_t num_kv_splits) {
|
||||||
cutlass::KernelHardwareInfo hw_info;
|
cutlass::KernelHardwareInfo hw_info;
|
||||||
hw_info.device_id = q_nope_and_q_pe.device().index();
|
hw_info.device_id = q_nope.device().index();
|
||||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||||
|
|
||||||
int batches = q_nope_and_q_pe.sizes()[0];
|
int batches = q_nope.sizes()[0];
|
||||||
int page_count_per_seq = page_table.sizes()[1];
|
int page_count_per_seq = page_table.sizes()[1];
|
||||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||||
@@ -122,8 +125,11 @@ typename T::Fmha::Arguments args_from_options(
|
|||||||
using StrideO = typename T::StrideO;
|
using StrideO = typename T::StrideO;
|
||||||
using StrideLSE = typename T::StrideLSE;
|
using StrideLSE = typename T::StrideLSE;
|
||||||
|
|
||||||
StrideQ stride_Q = cute::make_tuple(
|
StrideQ stride_Q_nope = cute::make_tuple(
|
||||||
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
|
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
|
||||||
|
StrideQ stride_Q_pe = cute::make_tuple(
|
||||||
|
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
|
||||||
|
|
||||||
StrideK stride_C = cute::make_tuple(
|
StrideK stride_C = cute::make_tuple(
|
||||||
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||||
@@ -133,15 +139,16 @@ typename T::Fmha::Arguments args_from_options(
|
|||||||
using Element = typename T::Element;
|
using Element = typename T::Element;
|
||||||
using ElementOut = typename T::ElementOut;
|
using ElementOut = typename T::ElementOut;
|
||||||
using ElementAcc = typename T::ElementAcc;
|
using ElementAcc = typename T::ElementAcc;
|
||||||
auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
|
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||||
|
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||||
typename T::Fmha::Arguments arguments{
|
typename T::Fmha::Arguments arguments{
|
||||||
problem_shape,
|
problem_shape,
|
||||||
{scale,
|
{scale,
|
||||||
Q_ptr,
|
Q_nope_ptr,
|
||||||
stride_Q,
|
stride_Q_nope,
|
||||||
Q_ptr + D_latent,
|
Q_pe_ptr,
|
||||||
stride_Q,
|
stride_Q_pe,
|
||||||
C_ptr,
|
C_ptr,
|
||||||
stride_C,
|
stride_C,
|
||||||
C_ptr + D_latent,
|
C_ptr + D_latent,
|
||||||
@@ -170,7 +177,8 @@ typename T::Fmha::Arguments args_from_options(
|
|||||||
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
||||||
void runMla(
|
void runMla(
|
||||||
at::Tensor const& out,
|
at::Tensor const& out,
|
||||||
at::Tensor const& q_nope_and_q_pe,
|
at::Tensor const& q_nope,
|
||||||
|
at::Tensor const& q_pe,
|
||||||
at::Tensor const& kv_c_and_k_pe_cache,
|
at::Tensor const& kv_c_and_k_pe_cache,
|
||||||
at::Tensor const& seq_lens,
|
at::Tensor const& seq_lens,
|
||||||
at::Tensor const& page_table,
|
at::Tensor const& page_table,
|
||||||
@@ -179,7 +187,7 @@ void runMla(
|
|||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||||
typename MlaSm100Type::Fmha fmha;
|
typename MlaSm100Type::Fmha fmha;
|
||||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
|
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
|
||||||
|
|
||||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||||
|
|
||||||
@@ -201,15 +209,16 @@ void runMla(
|
|||||||
|
|
||||||
void cutlass_mla_decode(
|
void cutlass_mla_decode(
|
||||||
torch::Tensor const& out,
|
torch::Tensor const& out,
|
||||||
torch::Tensor const& q_nope_and_q_pe,
|
torch::Tensor const& q_nope,
|
||||||
|
torch::Tensor const& q_pe,
|
||||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||||
torch::Tensor const& seq_lens,
|
torch::Tensor const& seq_lens,
|
||||||
torch::Tensor const& page_table,
|
torch::Tensor const& page_table,
|
||||||
torch::Tensor const& workspace,
|
torch::Tensor const& workspace,
|
||||||
int64_t num_kv_splits) {
|
int64_t num_kv_splits) {
|
||||||
auto in_dtype = q_nope_and_q_pe.dtype();
|
auto in_dtype = q_nope.dtype();
|
||||||
at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
|
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||||
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||||
|
|
||||||
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
||||||
@@ -219,13 +228,13 @@ void cutlass_mla_decode(
|
|||||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||||
if (in_dtype == at::ScalarType::Half) {
|
if (in_dtype == at::ScalarType::Half) {
|
||||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||||
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
||||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||||
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
||||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||||
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
|
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
|
||||||
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
|
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
|
||||||
m.def(
|
m.def(
|
||||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
|
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
|
||||||
"page_table, Tensor! workspace, int num_kv_splits) -> ()");
|
"page_table, Tensor! workspace, int num_kv_splits) -> ()");
|
||||||
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||||
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
|
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);
|
||||||
|
|||||||
@@ -105,7 +105,8 @@ void merge_state_v2(
|
|||||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
|
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
|
||||||
void cutlass_mla_decode(
|
void cutlass_mla_decode(
|
||||||
torch::Tensor const& out,
|
torch::Tensor const& out,
|
||||||
torch::Tensor const& q_nope_and_q_pe,
|
torch::Tensor const& q_nope,
|
||||||
|
torch::Tensor const& q_pe,
|
||||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||||
torch::Tensor const& seq_lens,
|
torch::Tensor const& seq_lens,
|
||||||
torch::Tensor const& page_table,
|
torch::Tensor const& page_table,
|
||||||
|
|||||||
@@ -52,34 +52,42 @@ def merge_state_v2(
|
|||||||
|
|
||||||
|
|
||||||
def cutlass_mla_decode(
|
def cutlass_mla_decode(
|
||||||
q_nope_and_q_pe: torch.Tensor,
|
q_nope: torch.Tensor,
|
||||||
|
q_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
page_table: torch.Tensor,
|
page_table: torch.Tensor,
|
||||||
workspace: torch.Tensor,
|
workspace: torch.Tensor,
|
||||||
num_kv_splits: int = -1,
|
num_kv_splits: int = -1,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert (
|
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||||
q_nope_and_q_pe.ndim == 3
|
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||||
), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
|
|
||||||
assert (
|
assert (
|
||||||
kv_c_and_k_pe_cache.ndim == 3
|
kv_c_and_k_pe_cache.ndim == 3
|
||||||
), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
|
), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
|
||||||
B_q, H, D_q = q_nope_and_q_pe.shape
|
|
||||||
|
B_q, H, D_q_nope = q_nope.shape
|
||||||
|
B_q_2, H_2, D_q_pe = q_pe.shape
|
||||||
|
assert (B_q == B_q_2) and (H == H_2)
|
||||||
|
|
||||||
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
|
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
|
||||||
|
|
||||||
D_latent = 512
|
D_latent = 512
|
||||||
D_rope = 64
|
D_rope = 64
|
||||||
assert D_q == D_ckv and D_q == D_latent + D_rope, (
|
assert D_q_nope == D_latent
|
||||||
f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
|
assert D_q_pe == D_rope
|
||||||
f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
|
assert D_ckv == D_latent + D_rope
|
||||||
)
|
|
||||||
MAX_HEADS = 128
|
MAX_HEADS = 128
|
||||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||||
if H < MAX_HEADS:
|
if H < MAX_HEADS:
|
||||||
q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
|
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
|
||||||
q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
|
q_nope_padded[:, :H] = q_nope
|
||||||
q_nope_and_q_pe = q_nope_and_q_pe_padded
|
q_nope = q_nope_padded
|
||||||
|
|
||||||
|
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
|
||||||
|
q_pe_padded[:, :H] = q_pe
|
||||||
|
q_pe = q_pe_padded
|
||||||
|
|
||||||
assert len(page_table.shape) == 2
|
assert len(page_table.shape) == 2
|
||||||
B_block_table, block_num = page_table.shape
|
B_block_table, block_num = page_table.shape
|
||||||
@@ -88,14 +96,11 @@ def cutlass_mla_decode(
|
|||||||
assert block_num % (128 / PAGE_SIZE) == 0
|
assert block_num % (128 / PAGE_SIZE) == 0
|
||||||
|
|
||||||
# TODO(kaixih@nvidia): support fp8
|
# TODO(kaixih@nvidia): support fp8
|
||||||
assert q_nope_and_q_pe.dtype in (
|
assert q_nope.dtype in (
|
||||||
torch.float16,
|
torch.float16,
|
||||||
torch.bfloat16,
|
torch.bfloat16,
|
||||||
), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
|
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
|
||||||
assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
|
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
|
||||||
f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, "
|
|
||||||
f"but got {kv_c_and_k_pe_cache.dtype}."
|
|
||||||
)
|
|
||||||
assert (
|
assert (
|
||||||
seq_lens.dtype == torch.int32
|
seq_lens.dtype == torch.int32
|
||||||
), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
|
), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
|
||||||
@@ -103,11 +108,12 @@ def cutlass_mla_decode(
|
|||||||
page_table.dtype == torch.int32
|
page_table.dtype == torch.int32
|
||||||
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||||
|
|
||||||
out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))
|
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
|
||||||
|
|
||||||
torch.ops.sgl_kernel.cutlass_mla_decode.default(
|
torch.ops.sgl_kernel.cutlass_mla_decode.default(
|
||||||
out,
|
out,
|
||||||
q_nope_and_q_pe,
|
q_nope,
|
||||||
|
q_pe,
|
||||||
kv_c_and_k_pe_cache,
|
kv_c_and_k_pe_cache,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
page_table,
|
page_table,
|
||||||
|
|||||||
@@ -86,10 +86,14 @@ def test_cutlass_mla_decode(
|
|||||||
)
|
)
|
||||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||||
|
|
||||||
|
q_nope = torch.empty((h_q, bs, dv)).transpose(0, 1)
|
||||||
|
q_nope.copy_(q[:, :, :dv])
|
||||||
|
q_pe = q[:, :, dv:].clone()
|
||||||
|
|
||||||
out_ref = q.new_zeros(bs, h_q, dv)
|
out_ref = q.new_zeros(bs, h_q, dv)
|
||||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||||
out = cutlass_mla_decode(
|
out = cutlass_mla_decode(
|
||||||
q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
|
q_nope, q_pe, kv_cache, seq_lens, block_table, workspace, num_kv_splits
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2)
|
||||||
|
|||||||
Reference in New Issue
Block a user