[sgl-kernel] support flashmla libtorch (#11717)
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
|
||||
project(sgl-kernel LANGUAGES CXX CUDA)
|
||||
|
||||
# utils
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
include(FetchContent)
|
||||
|
||||
# CMake
|
||||
cmake_policy(SET CMP0169 OLD)
|
||||
cmake_policy(SET CMP0177 NEW)
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
set(CMAKE_COLOR_DIAGNOSTICS ON)
|
||||
set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
@@ -37,11 +40,9 @@ endif()
|
||||
|
||||
# Torch
|
||||
find_package(Torch REQUIRED)
|
||||
# clean Torch Flag
|
||||
clear_cuda_arches(CMAKE_FLAG)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
# Third Party
|
||||
# cutlass
|
||||
FetchContent_Declare(
|
||||
repo-cutlass
|
||||
@@ -69,7 +70,7 @@ FetchContent_Declare(
|
||||
)
|
||||
FetchContent_Populate(repo-fmt)
|
||||
|
||||
# Triton
|
||||
# Triton kernel
|
||||
FetchContent_Declare(
|
||||
repo-triton
|
||||
GIT_REPOSITORY "https://github.com/triton-lang/triton"
|
||||
@@ -143,12 +144,6 @@ endif()
|
||||
include_directories(
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/csrc
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-mscclpp_SOURCE_DIR}/include
|
||||
${repo-fast-hadamard-transform}/csrc
|
||||
)
|
||||
|
||||
set(SGL_KERNEL_CUDA_FLAGS
|
||||
@@ -350,6 +345,7 @@ set(SOURCES
|
||||
"${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
|
||||
)
|
||||
|
||||
# =========================== Common SM90 Build ============================= #
|
||||
# Build SM90 library with fast math optimization (same namespace, different directory)
|
||||
Python_add_library(common_ops_sm90_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||
|
||||
@@ -360,7 +356,11 @@ target_compile_options(common_ops_sm90_build PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS} -use_fast_math>
|
||||
)
|
||||
target_include_directories(common_ops_sm90_build PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/csrc
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-mscclpp_SOURCE_DIR}/include
|
||||
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
|
||||
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
||||
@@ -371,6 +371,7 @@ set_target_properties(common_ops_sm90_build PROPERTIES
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm90"
|
||||
)
|
||||
|
||||
# =========================== Common SM100+ Build ============================= #
|
||||
# Build SM100+ library with precise math (same namespace, different directory)
|
||||
Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
|
||||
|
||||
@@ -381,7 +382,11 @@ target_compile_options(common_ops_sm100_build PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>
|
||||
)
|
||||
target_include_directories(common_ops_sm100_build PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/csrc
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flashinfer_SOURCE_DIR}/include
|
||||
${repo-flashinfer_SOURCE_DIR}/csrc
|
||||
${repo-mscclpp_SOURCE_DIR}/include
|
||||
${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha
|
||||
${repo-cutlass_SOURCE_DIR}/examples/common
|
||||
${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src
|
||||
@@ -408,7 +413,7 @@ else()
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=1")
|
||||
endif()
|
||||
|
||||
# mscclpp
|
||||
# mscclpp option
|
||||
set(MSCCLPP_USE_CUDA ON)
|
||||
set(MSCCLPP_BYPASS_GPU_CHECK ON)
|
||||
set(MSCCLPP_BUILD_TESTS OFF)
|
||||
@@ -419,7 +424,7 @@ add_subdirectory(
|
||||
target_link_libraries(common_ops_sm90_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
|
||||
target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static)
|
||||
|
||||
# flash attention
|
||||
# sparse flash attention
|
||||
target_compile_definitions(common_ops_sm90_build PRIVATE
|
||||
FLASHATTENTION_DISABLE_BACKWARD
|
||||
FLASHATTENTION_DISABLE_DROPOUT
|
||||
@@ -506,6 +511,8 @@ if (SGL_KERNEL_ENABLE_FA3)
|
||||
|
||||
target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
|
||||
target_include_directories(flash_ops PRIVATE
|
||||
${repo-cutlass_SOURCE_DIR}/include
|
||||
${repo-cutlass_SOURCE_DIR}/tools/util/include
|
||||
${repo-flash-attention_SOURCE_DIR}/hopper
|
||||
)
|
||||
target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
|
||||
@@ -535,6 +542,8 @@ target_compile_options(spatial_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERN
|
||||
target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
|
||||
install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel)
|
||||
|
||||
# ============================ Extra Install ============================= #
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/flashmla.cmake)
|
||||
|
||||
# ============================ DeepGEMM (JIT) ============================= #
|
||||
# Create a separate library for DeepGEMM's Python API.
|
||||
|
||||
60
sgl-kernel/cmake/flashmla.cmake
Normal file
60
sgl-kernel/cmake/flashmla.cmake
Normal file
@@ -0,0 +1,60 @@
|
||||
include(FetchContent)
|
||||
|
||||
# flash_mla
|
||||
FetchContent_Declare(
|
||||
repo-flashmla
|
||||
GIT_REPOSITORY https://github.com/sgl-project/FlashMLA
|
||||
GIT_TAG bc8576abc3e507425cf6498f3d3393df7733ce37
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-flashmla)
|
||||
|
||||
set(FLASHMLA_CUDA_FLAGS
|
||||
"--expt-relaxed-constexpr"
|
||||
"--expt-extended-lambda"
|
||||
"--use_fast_math"
|
||||
)
|
||||
|
||||
# The FlashMLA kernels only work on hopper and require CUDA 12.4 or later.
|
||||
# Only build FlashMLA kernels if we are building for something compatible with
|
||||
# sm90a
|
||||
if(${CUDA_VERSION} VERSION_GREATER 12.4)
|
||||
list(APPEND FLASHMLA_CUDA_FLAGS
|
||||
"-gencode=arch=compute_90a,code=sm_90a"
|
||||
)
|
||||
endif()
|
||||
if(${CUDA_VERSION} VERSION_GREATER 12.8)
|
||||
list(APPEND FLASHMLA_CUDA_FLAGS
|
||||
"-gencode=arch=compute_100a,code=sm_100a"
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
set(FlashMLA_SOURCES
|
||||
"csrc/flashmla_extension.cc"
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/python_api.cpp
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
|
||||
)
|
||||
|
||||
Python_add_library(flashmla_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FlashMLA_SOURCES})
|
||||
target_compile_options(flashmla_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${FLASHMLA_CUDA_FLAGS}>)
|
||||
target_include_directories(flashmla_ops PRIVATE
|
||||
${repo-flashmla_SOURCE_DIR}/csrc
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${repo-flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
|
||||
target_link_libraries(flashmla_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
|
||||
|
||||
install(TARGETS flashmla_ops LIBRARY DESTINATION "sgl_kernel")
|
||||
|
||||
target_compile_definitions(flashmla_ops PRIVATE)
|
||||
46
sgl-kernel/csrc/flashmla_extension.cc
Normal file
46
sgl-kernel/csrc/flashmla_extension.cc
Normal file
@@ -0,0 +1,46 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
/*
|
||||
* From FlashMLA
|
||||
*/
|
||||
m.def(
|
||||
"get_mla_decoding_metadata(Tensor seqlens_k, int num_q_tokens_per_head_k, int h_k, int? h_q, bool "
|
||||
"is_fp8_kvcache, int? topk) -> Tensor[]");
|
||||
m.impl("get_mla_decoding_metadata", torch::kCUDA, &get_mla_decoding_metadata);
|
||||
|
||||
m.def(
|
||||
"fwd_kvcache_mla(Tensor q, Tensor kv_cache, int head_size_v, Tensor seqlens_k, Tensor block_table, float "
|
||||
"softmax_scale, bool is_causal, Tensor tile_scheduler_metadata, Tensor num_splits, bool is_fp8, Tensor? indices) "
|
||||
"-> Tensor[]");
|
||||
m.impl("fwd_kvcache_mla", torch::kCUDA, &fwd_kvcache_mla);
|
||||
|
||||
m.def(
|
||||
"dense_prefill_fwd(Tensor workspace_buffer, Tensor q, Tensor k, Tensor v, Tensor cumulative_seqlen_q, Tensor "
|
||||
"cumulative_seqlen_kv, Tensor o, Tensor lse, int mask_mode_code, float softmax_scale, int max_seqlen_q, int "
|
||||
"max_seqlen_kv, bool is_varlen) -> ()");
|
||||
m.impl("dense_prefill_fwd", torch::kCUDA, &FMHACutlassSM100FwdRun);
|
||||
|
||||
m.def("sparse_prefill_fwd(Tensor q, Tensor kv, Tensor indices, float sm_scale, int d_v) -> Tensor[]");
|
||||
m.impl("sparse_prefill_fwd", torch::kCUDA, &sparse_prefill_fwd);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(flashmla_ops)
|
||||
@@ -842,6 +842,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace);
|
||||
|
||||
/*
|
||||
* From fast-hadamard-transform
|
||||
*/
|
||||
@@ -850,3 +851,47 @@ torch::Tensor fast_hadamard_transform_12N(torch::Tensor& x, double scale);
|
||||
torch::Tensor fast_hadamard_transform_20N(torch::Tensor& x, double scale);
|
||||
torch::Tensor fast_hadamard_transform_28N(torch::Tensor& x, double scale);
|
||||
torch::Tensor fast_hadamard_transform_40N(torch::Tensor& x, double scale);
|
||||
|
||||
/*
|
||||
* From csrc/fastertransformer
|
||||
*/
|
||||
std::vector<at::Tensor> get_mla_decoding_metadata(
|
||||
at::Tensor& seqlens_k,
|
||||
const int64_t num_q_tokens_per_head_k,
|
||||
const int64_t h_k,
|
||||
const std::optional<int64_t> h_q,
|
||||
const bool is_fp8_kvcache,
|
||||
const std::optional<int64_t> topk);
|
||||
|
||||
std::vector<at::Tensor> fwd_kvcache_mla(
|
||||
at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor& kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or
|
||||
// num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
|
||||
const int64_t head_size_v,
|
||||
const at::Tensor& seqlens_k, // batch_size
|
||||
const at::Tensor& block_table, // batch_size x max_num_blocks_per_seq
|
||||
const double softmax_scale,
|
||||
bool is_causal,
|
||||
const at::Tensor& tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
|
||||
const at::Tensor& num_splits, // batch_size + 1
|
||||
const bool& is_fp8,
|
||||
const std::optional<at::Tensor>& indices // None, or batch_size x seqlen_q x topk
|
||||
);
|
||||
|
||||
void FMHACutlassSM100FwdRun(
|
||||
at::Tensor workspace_buffer,
|
||||
at::Tensor q,
|
||||
at::Tensor k,
|
||||
at::Tensor v,
|
||||
at::Tensor cumulative_seqlen_q,
|
||||
at::Tensor cumulative_seqlen_kv,
|
||||
at::Tensor o,
|
||||
at::Tensor lse,
|
||||
int64_t mask_mode_code,
|
||||
double softmax_scale,
|
||||
int64_t max_seqlen_q,
|
||||
int64_t max_seqlen_kv,
|
||||
bool is_varlen);
|
||||
|
||||
std::vector<at::Tensor>
|
||||
sparse_prefill_fwd(const at::Tensor& q, const at::Tensor& kv, const at::Tensor& indices, double sm_scale, int64_t d_v);
|
||||
|
||||
126
sgl-kernel/python/sgl_kernel/flash_mla.py
Normal file
126
sgl-kernel/python/sgl_kernel/flash_mla.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from . import flashmla_ops # triggers TORCH extension registration
|
||||
except Exception as _e:
|
||||
_flashmla_import_error = _e
|
||||
else:
|
||||
_flashmla_import_error = None
|
||||
|
||||
_IMPORT_ERROR = ImportError(
|
||||
"Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
|
||||
)
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_q_tokens_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
num_heads_q: Optional[int] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
topk: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
|
||||
num_heads_k: The number of k heads.
|
||||
num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled
|
||||
is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
|
||||
topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to.
|
||||
|
||||
Returns:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return torch.ops.sgl_kernel.get_mla_decoding_metadata.default(
|
||||
cache_seqlens,
|
||||
num_q_tokens_per_head_k,
|
||||
num_heads_k,
|
||||
num_heads_q,
|
||||
is_fp8_kvcache,
|
||||
topk,
|
||||
)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
q: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
cache_seqlens: torch.Tensor,
|
||||
head_dim_v: int,
|
||||
tile_scheduler_metadata: torch.Tensor,
|
||||
num_splits: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
is_fp8_kvcache: bool = False,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head dimension of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md
|
||||
indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md.
|
||||
|
||||
Returns:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
if indices is not None:
|
||||
assert causal == False, "causal must be `false` if sparse attention is enabled."
|
||||
out, softmax_lse = torch.ops.sgl_kernel.fwd_kvcache_mla.default(
|
||||
q,
|
||||
k_cache,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
is_fp8_kvcache,
|
||||
indices,
|
||||
)
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
def flash_mla_sparse_fwd(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
sm_scale: float,
|
||||
d_v: int = 512,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sparse attention prefill kernel
|
||||
|
||||
Args:
|
||||
q: [s_q, h_q, d_qk], bfloat16
|
||||
kv: [s_kv, h_kv, d_qk], bfloat16
|
||||
indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
|
||||
sm_scale: float
|
||||
d_v: The dimension of value vectors. Can only be 512
|
||||
|
||||
Returns:
|
||||
(output, max_logits, lse)
|
||||
About the definition of output, max_logits and lse, please refer to README.md
|
||||
- output: [s_q, h_q, d_v], bfloat16
|
||||
- max_logits: [s_q, h_q], float
|
||||
- lse: [s_q, h_q], float, 2-based log-sum-exp
|
||||
"""
|
||||
results = torch.ops.sgl_kernel.sparse_prefill_fwd.default(
|
||||
q, kv, indices, sm_scale, d_v
|
||||
)
|
||||
return results
|
||||
518
sgl-kernel/tests/test_flashmla.py
Normal file
518
sgl-kernel/tests/test_flashmla.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import math
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
from sgl_kernel.flash_mla import (
|
||||
flash_mla_sparse_fwd,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
|
||||
skip_condition = torch.cuda.get_device_capability() < (10, 0)
|
||||
|
||||
# ================ prefill usage ================ #
|
||||
S_Q_PREFILL = [1, 62]
|
||||
KV_TOPK_PREFILL = [
|
||||
# Regular shapes
|
||||
(128, 128),
|
||||
(256, 256),
|
||||
(512, 512),
|
||||
# Irregular shapes
|
||||
(592, 128),
|
||||
(1840, 256),
|
||||
(1592, 384),
|
||||
(1521, 512),
|
||||
# Irregular shapes with OOB TopK
|
||||
(95, 128),
|
||||
(153, 256),
|
||||
(114, 384),
|
||||
]
|
||||
|
||||
# ================= decode usage ================= #
|
||||
B_DECODE = [1, 2, 6, 64]
|
||||
S_Q_DECODE = [1, 2, 4]
|
||||
S_K_DECODE = [20, 140, 4096]
|
||||
IS_VARLEN = [False, True]
|
||||
CAUSAL_TOPK = [(True, None), (False, None), (False, 128), (False, 2048)]
|
||||
DTYPE = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
def quantize_k_cache(
|
||||
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
|
||||
dv: int,
|
||||
tile_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Quantize the k-cache
|
||||
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
|
||||
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
|
||||
"""
|
||||
assert dv % tile_size == 0
|
||||
num_tiles = dv // tile_size
|
||||
num_blocks, block_size, h_k, d = input_k_cache.shape
|
||||
assert h_k == 1
|
||||
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
|
||||
input_elem_size = input_k_cache.element_size()
|
||||
|
||||
result = torch.empty(
|
||||
(num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
device=input_k_cache.device,
|
||||
)
|
||||
result_k_nope_part = result[..., :dv]
|
||||
result_k_scale_factor = result[..., dv : dv + num_tiles * 4].view(torch.float32)
|
||||
result_k_rope_part = result[..., dv + num_tiles * 4 :].view(input_k_cache.dtype)
|
||||
result_k_rope_part[:] = input_k_cache[..., dv:]
|
||||
|
||||
for tile_idx in range(0, num_tiles):
|
||||
cur_scale_factors_inv = (
|
||||
torch.abs(
|
||||
input_k_cache[..., tile_idx * tile_size : (tile_idx + 1) * tile_size]
|
||||
)
|
||||
.max(dim=-1)
|
||||
.values
|
||||
/ 448.0
|
||||
) # [num_blocks, block_size]
|
||||
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
|
||||
|
||||
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
|
||||
cur_quantized_nope = (
|
||||
input_k_cache[
|
||||
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
|
||||
].float()
|
||||
/ cur_scale_factors_inv.float()
|
||||
).to(torch.float8_e4m3fn)
|
||||
result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
|
||||
cur_quantized_nope
|
||||
)
|
||||
|
||||
result = result.view(num_blocks, block_size, 1, -1)
|
||||
return result
|
||||
|
||||
|
||||
def dequantize_k_cache(
|
||||
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
|
||||
dv: int = 512,
|
||||
tile_size: int = 128,
|
||||
d: int = 576,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
De-quantize the k-cache
|
||||
"""
|
||||
assert dv % tile_size == 0
|
||||
num_tiles = dv // tile_size
|
||||
num_blocks, block_size, h_k, _ = quant_k_cache.shape
|
||||
assert h_k == 1
|
||||
result = torch.empty(
|
||||
(num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device
|
||||
)
|
||||
|
||||
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
|
||||
|
||||
input_nope = quant_k_cache[..., :dv]
|
||||
input_scale = quant_k_cache[..., dv : dv + num_tiles * 4].view(torch.float32)
|
||||
input_rope = quant_k_cache[..., dv + num_tiles * 4 :].view(torch.bfloat16)
|
||||
result[..., dv:] = input_rope
|
||||
|
||||
for tile_idx in range(0, num_tiles):
|
||||
cur_nope = input_nope[
|
||||
..., tile_idx * tile_size : (tile_idx + 1) * tile_size
|
||||
].to(torch.float32)
|
||||
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
|
||||
result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = (
|
||||
cur_nope * cur_scales
|
||||
)
|
||||
|
||||
result = result.view(num_blocks, block_size, 1, d)
|
||||
return result
|
||||
|
||||
|
||||
def cdiv(x: int, y: int):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def get_window_size(causal, window):
|
||||
if window > 0:
|
||||
window_size = (window - 1, 0) if causal else (window - 1, window - 1)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
return window_size
|
||||
|
||||
|
||||
def get_attn_bias(s_q, s_k, causal, window):
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float32, device="cuda")
|
||||
if causal:
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
|
||||
diagonal=s_k - s_q
|
||||
)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
if window > 0:
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
|
||||
diagonal=s_k - s_q - window
|
||||
)
|
||||
attn_bias.masked_fill_(temp_mask, float("-inf"))
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda").tril(
|
||||
diagonal=s_k - s_q + window - 1
|
||||
)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
return attn_bias
|
||||
|
||||
|
||||
def sdpa(query, key, value, attn_bias, softmax_scale=None):
|
||||
query = query.float().transpose(-3, -2)
|
||||
key = key.float().transpose(-3, -2)
|
||||
value = value.float().transpose(-3, -2)
|
||||
key = key.repeat_interleave(h // h_k, dim=-3)
|
||||
value = value.repeat_interleave(h // h_k, dim=-3)
|
||||
if softmax_scale is None:
|
||||
softmax_scale = query.shape[-1] ** (-0.5)
|
||||
attn_weight = (query @ key.transpose(-2, -1)) * softmax_scale
|
||||
attn_weight += attn_bias
|
||||
lse = attn_weight.logsumexp(dim=-1)
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
return attn_weight.to(query.dtype) @ value, lse
|
||||
|
||||
|
||||
def sdpa_checkpoint(*args, **kwargs):
|
||||
return checkpoint(sdpa, *args, use_reentrant=False, **kwargs)
|
||||
|
||||
|
||||
def reference_torch_prefill(
|
||||
s_q, s_kv, topk, indices, q, kv, sm_scale: float
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
|
||||
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
|
||||
|
||||
indices = indices[0, :, 0, :] # [s_q, topk]
|
||||
invalid_indices_mask = (indices < 0) | (indices >= s_kv)
|
||||
qs = q[0, :, :, :].float() # [s_q, h_q, d_qk]
|
||||
kvs = kv[0, :, 0, :].float() # [s_kv, d_qk]
|
||||
|
||||
kvs = torch.index_select(
|
||||
kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()
|
||||
).view(
|
||||
s_q, topk, 576
|
||||
) # [s_q, topk, d_qk]
|
||||
attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk]
|
||||
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
|
||||
attn_score *= sm_scale * math.log2(math.e)
|
||||
max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q]
|
||||
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
|
||||
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
|
||||
result = attn_score @ kvs[:, :, :512]
|
||||
return (max_logits, lse, result)
|
||||
|
||||
|
||||
def reference_torch_decode(
|
||||
cache_seqlens: torch.Tensor, # [batch_size]
|
||||
block_table: torch.Tensor, # [batch_size, ?]
|
||||
q: torch.Tensor, # [batch_size, s_q, h_q, d]
|
||||
blocked_k: torch.Tensor, # [?, block_size, h_kv, d]
|
||||
dv: int,
|
||||
is_causal: bool,
|
||||
indices: Optional[torch.Tensor] = None, # [batch_size, s_q, topk]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
A reference implementation in PyTorch
|
||||
"""
|
||||
|
||||
def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
|
||||
mask = torch.zeros(s_q, s_k, dtype=torch.bool, device="cuda")
|
||||
for i in range(s_q):
|
||||
cur_indices = indices[i]
|
||||
valid_indices = cur_indices[cur_indices != -1]
|
||||
mask[i, valid_indices] = True
|
||||
return mask
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
batch_idx: int,
|
||||
query: torch.Tensor, # [h_q, s_q, d]
|
||||
kv: torch.Tensor, # [h_kv, s_k, d]
|
||||
dv: int,
|
||||
is_causal,
|
||||
indices: Optional[torch.Tensor], # [s_q, topk]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
h_q = query.size(0)
|
||||
h_kv = kv.size(0)
|
||||
s_q = query.shape[-2]
|
||||
s_k = kv.shape[-2]
|
||||
query = query.float()
|
||||
kv = kv.float()
|
||||
if h_kv != 1:
|
||||
kv = kv.repeat_interleave(h_q // h_kv, dim=0)
|
||||
kv[kv != kv] = 0.0
|
||||
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
|
||||
if (is_causal and query.size(1) > 1) or indices is not None:
|
||||
mask = torch.ones(s_q, s_k, dtype=torch.bool, device="cuda")
|
||||
if is_causal:
|
||||
assert indices is None
|
||||
mask = mask.tril(diagonal=s_k - s_q)
|
||||
if indices is not None:
|
||||
mask &= get_topk_attn_mask(s_q, s_k, indices)
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float, device="cuda")
|
||||
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
|
||||
attn_weight += attn_bias.to(q.dtype)
|
||||
attn_weight /= math.sqrt(query.size(-1))
|
||||
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
|
||||
# Correct for q tokens which has no attendable k
|
||||
lonely_q_mask = lse == float("-inf")
|
||||
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
|
||||
lse[lonely_q_mask] = float("+inf")
|
||||
|
||||
return output, lse
|
||||
|
||||
b, s_q, h_q, d = q.size()
|
||||
block_size = blocked_k.size(1)
|
||||
h_kv = blocked_k.size(2)
|
||||
cache_seqlens_cpu = cache_seqlens.cpu()
|
||||
out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device="cuda")
|
||||
lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32, device="cuda")
|
||||
for i in range(b):
|
||||
cur_len = cache_seqlens_cpu[i].item()
|
||||
cur_num_blocks = cdiv(cur_len, block_size)
|
||||
cur_block_indices = block_table[i][0:cur_num_blocks]
|
||||
cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
|
||||
cur_out, cur_lse = scaled_dot_product_attention(
|
||||
i,
|
||||
q[i].transpose(0, 1),
|
||||
cur_kv.transpose(0, 1),
|
||||
dv,
|
||||
is_causal,
|
||||
indices[i] if indices is not None else None,
|
||||
)
|
||||
out_ref[i] = cur_out.transpose(0, 1)
|
||||
lse_ref[i] = cur_lse
|
||||
out_ref = out_ref.to(torch.bfloat16)
|
||||
return out_ref, lse_ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize("s_q", S_Q_PREFILL)
|
||||
@pytest.mark.parametrize("kv_topk", KV_TOPK_PREFILL)
|
||||
@torch.inference_mode()
|
||||
def test_flashmla_prefill(
|
||||
s_q: int,
|
||||
kv_topk: Tuple[int, int],
|
||||
):
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
q = torch.randn((1, s_q, 128, 576), dtype=torch.bfloat16, device="cuda") / 10
|
||||
kv = torch.randn((1, kv_topk[0], 1, 576), dtype=torch.bfloat16, device="cuda") / 10
|
||||
|
||||
q.clamp_(-10, 10)
|
||||
kv.clamp_(-10, 10)
|
||||
|
||||
indices = torch.full(
|
||||
(1, s_q, 1, kv_topk[1]), kv_topk[0], dtype=torch.int32, device="cuda"
|
||||
)
|
||||
for s in range(s_q):
|
||||
# NOTE We use the following method to generate indices so that most indices lies within [s_kv-20000, s_kv), which is more realistic for sparse attention
|
||||
near_mask = (
|
||||
torch.randint(0, 32, (min(kv_topk[1], kv_topk[0]),), device="cuda") < 31
|
||||
)
|
||||
cur_indices = torch.randperm(kv_topk[0], device="cuda")[: kv_topk[1]]
|
||||
cur_indices[near_mask] = torch.randint(
|
||||
max(0, kv_topk[0] - 20000),
|
||||
kv_topk[0] - 1,
|
||||
(near_mask.sum().item(),),
|
||||
device="cuda",
|
||||
)
|
||||
if len(cur_indices) < kv_topk[1]:
|
||||
cur_indices = torch.cat(
|
||||
[
|
||||
cur_indices,
|
||||
torch.full(
|
||||
(kv_topk[1] - len(cur_indices),), 2147480000, device="cuda"
|
||||
),
|
||||
]
|
||||
)
|
||||
cur_indices = cur_indices[torch.randperm(kv_topk[1], device="cuda")]
|
||||
indices[0, s, 0] = cur_indices
|
||||
indices = indices.to(q.device)
|
||||
|
||||
sm_scale = 1 / math.sqrt(576)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
ans_out, ans_max_logits, ans_lse = flash_mla_sparse_fwd(
|
||||
q.squeeze(0), kv.squeeze(0), indices.squeeze(0), sm_scale=sm_scale
|
||||
)
|
||||
|
||||
ans_out, ans_max_logits, ans_lse = (
|
||||
ans_out.float(),
|
||||
ans_max_logits.float(),
|
||||
ans_lse.float(),
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
ref_max_logits, ref_lse, ref_out = reference_torch_prefill(
|
||||
s_q, kv_topk[0], kv_topk[1], indices, q, kv, sm_scale
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(ans_out, ref_out, atol=8e-4, rtol=2.01 / 128)
|
||||
torch.testing.assert_close(
|
||||
ans_max_logits,
|
||||
ref_max_logits,
|
||||
atol=1e-6,
|
||||
rtol=2.01 / 65536,
|
||||
)
|
||||
torch.testing.assert_close(ans_lse, ref_lse, atol=1e-6, rtol=2.01 / 65536)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("b", B_DECODE)
|
||||
@pytest.mark.parametrize("s_q", S_Q_DECODE)
|
||||
@pytest.mark.parametrize("s_k", S_K_DECODE)
|
||||
@pytest.mark.parametrize("is_varlen", IS_VARLEN)
|
||||
@pytest.mark.parametrize("causal_topk", CAUSAL_TOPK)
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_flash_mla_decode(
|
||||
b: int,
|
||||
s_q: int,
|
||||
s_k: int,
|
||||
is_varlen: bool,
|
||||
causal_topk: Tuple[bool, Optional[int]],
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
d = 576
|
||||
dv = 512
|
||||
block_size = 64
|
||||
h_q = 128
|
||||
h_kv = 1
|
||||
is_causal = causal_topk[0]
|
||||
topk = causal_topk[1]
|
||||
|
||||
# Generating test data
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device="cpu")
|
||||
if is_varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), s_q)
|
||||
|
||||
max_seqlen = cache_seqlens_cpu.max().item()
|
||||
max_seqlen_pad = cdiv(max_seqlen, 256) * 256
|
||||
cache_seqlens = cache_seqlens_cpu.cuda()
|
||||
|
||||
q = torch.randn(b, s_q, 128, d, dtype=torch.bfloat16, device="cuda")
|
||||
q.clamp_(min=-1.0, max=1.0)
|
||||
|
||||
block_table = torch.arange(
|
||||
b * max_seqlen_pad // block_size, dtype=torch.int32, device="cuda"
|
||||
).view(b, max_seqlen_pad // block_size)
|
||||
block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(b, -1)
|
||||
blocked_k = (
|
||||
torch.randn(
|
||||
block_table.numel(),
|
||||
block_size,
|
||||
h_kv,
|
||||
d,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
/ 10
|
||||
)
|
||||
blocked_k.clamp_(min=-1.0, max=1.0)
|
||||
|
||||
if topk is None:
|
||||
for i in range(b):
|
||||
cur_len = cache_seqlens_cpu[i].item()
|
||||
cur_num_blocks = cdiv(cur_len, block_size)
|
||||
blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
|
||||
if cur_len % block_size != 0:
|
||||
blocked_k[block_table[i][cur_num_blocks - 1]][
|
||||
cur_len % block_size :
|
||||
] = float("nan")
|
||||
block_table[i][cur_num_blocks:] = 2147480000
|
||||
abs_indices = None
|
||||
indices_in_kvcache = None
|
||||
else:
|
||||
block_table_cpu = block_table.cpu()
|
||||
abs_indices = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
|
||||
indices_in_kvcache = torch.empty(b, s_q, topk, dtype=torch.int32, device="cpu")
|
||||
for i in range(b):
|
||||
# Generate indices
|
||||
for j in range(s_q):
|
||||
cur_abs_indices = torch.randperm(
|
||||
int(cache_seqlens_cpu[i].item()), device="cpu"
|
||||
)[:topk]
|
||||
cur_blocked_indices = block_table_cpu[
|
||||
i, cur_abs_indices // block_size
|
||||
] * block_size + (cur_abs_indices % block_size)
|
||||
if len(cur_abs_indices) < topk:
|
||||
pad_len = topk - len(cur_abs_indices)
|
||||
cur_abs_indices = torch.cat(
|
||||
[cur_abs_indices, torch.full((pad_len,), -1, device="cpu")]
|
||||
)
|
||||
cur_blocked_indices = torch.cat(
|
||||
[cur_blocked_indices, torch.full((pad_len,), -1, device="cpu")]
|
||||
)
|
||||
|
||||
# Mask KV
|
||||
perm = torch.randperm(topk, device="cpu")
|
||||
cur_abs_indices = cur_abs_indices[perm]
|
||||
cur_blocked_indices = cur_blocked_indices[perm]
|
||||
|
||||
abs_indices[i, j, :] = cur_abs_indices
|
||||
indices_in_kvcache[i, j, :] = cur_blocked_indices
|
||||
|
||||
# Mask nonused KV as NaN
|
||||
all_indices = indices_in_kvcache.flatten().tolist()
|
||||
all_indices = list(set(all_indices))
|
||||
if -1 in all_indices:
|
||||
all_indices.remove(-1)
|
||||
all_indices = torch.tensor(all_indices, dtype=torch.int32, device="cpu")
|
||||
|
||||
blocked_k = blocked_k.view(-1, h_kv, d)
|
||||
nonused_indices_mask = torch.ones(
|
||||
blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device="cpu"
|
||||
)
|
||||
nonused_indices_mask[all_indices] = False
|
||||
blocked_k[nonused_indices_mask, :, :] = float("nan")
|
||||
blocked_k = blocked_k.view(-1, block_size, h_kv, d)
|
||||
|
||||
abs_indices = abs_indices.to(q.device)
|
||||
indices_in_kvcache = indices_in_kvcache.to(q.device)
|
||||
|
||||
is_fp8 = topk is not None
|
||||
if is_fp8:
|
||||
# The quantization error may be too large to be distinguished from wrong kernels
|
||||
# So we quantize and de-quantize kv-cache here to mitigate quantization error
|
||||
blocked_k_quantized = quantize_k_cache(blocked_k, dv, 128)
|
||||
blocked_k_dequantized = dequantize_k_cache(blocked_k_quantized)
|
||||
blocked_k = blocked_k_dequantized
|
||||
|
||||
# Get schedule metadata
|
||||
torch.cuda.synchronize()
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q, is_fp8, topk
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out_ans, lse_ans = flash_mla_with_kvcache(
|
||||
q,
|
||||
blocked_k if not is_fp8 else blocked_k_quantized, # type: ignore
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=is_causal,
|
||||
is_fp8_kvcache=is_fp8,
|
||||
indices=indices_in_kvcache,
|
||||
)
|
||||
|
||||
out_ref, lse_ref = reference_torch_decode(
|
||||
cache_seqlens, block_table, q, blocked_k, dv, is_causal, abs_indices
|
||||
)
|
||||
torch.testing.assert_close(out_ans, out_ref, atol=8e-4, rtol=2.01 / 128)
|
||||
torch.testing.assert_close(lse_ans, lse_ref, atol=1e-6, rtol=8.01 / 65536)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user