Disable kernel cutlass_mla_decode on SM103 (#10058)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
@@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
||||||
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
||||||
|
#include "utils.h"
|
||||||
|
|
||||||
// clang-format off
|
// clang-format off
|
||||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||||
@@ -217,6 +218,10 @@ void cutlass_mla_decode(
|
|||||||
torch::Tensor const& workspace,
|
torch::Tensor const& workspace,
|
||||||
double sm_scale,
|
double sm_scale,
|
||||||
int64_t num_kv_splits) {
|
int64_t num_kv_splits) {
|
||||||
|
auto sm_version = getSMVersion();
|
||||||
|
// On SM103a, half of the accuracy tests are failing.
|
||||||
|
TORCH_CHECK(sm_version == 100, "cutlass_mla_decode is only supported on compute capability 10.0, but found sm version ", sm_version);
|
||||||
|
|
||||||
auto in_dtype = q_nope.dtype();
|
auto in_dtype = q_nope.dtype();
|
||||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ import torch.nn.functional as F
|
|||||||
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
|
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
if torch.cuda.get_device_capability() < (10, 0):
|
# Disable tests on SM103 until the accuracy issues are fixed.
|
||||||
|
if torch.cuda.get_device_capability() != (10, 0):
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
reason="Cutlass MLA Requires compute capability of 10 or above.",
|
reason="Cutlass MLA Requires compute capability of 10.",
|
||||||
allow_module_level=True,
|
allow_module_level=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user