Update CUTLASS 4.2 & Enable K-Major Scale Factor for SM90 FP8 Blockwise Group GEMM (#9559)
This commit is contained in:
@@ -157,10 +157,6 @@ def cutlass_fused_experts_fp8(
|
||||
rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
|
||||
rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
|
||||
|
||||
if not is_sm100_supported():
|
||||
rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
|
||||
w1_scale = w1_scale.contiguous()
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
|
||||
@@ -192,9 +188,6 @@ def cutlass_fused_experts_fp8(
|
||||
silu_and_mul(c1, intermediate)
|
||||
|
||||
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
|
||||
if not is_sm100_supported():
|
||||
a2_scale = per_group_transpose(a2_scale, expert_offsets)
|
||||
w2_scale = w2_scale.contiguous()
|
||||
|
||||
fp8_blockwise_scaled_grouped_mm(
|
||||
c2,
|
||||
|
||||
@@ -8,6 +8,15 @@ from transformers import AutoConfig
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
|
||||
|
||||
|
||||
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
|
||||
def calc_diff(x, y):
    """Return a scalar dissimilarity between tensors ``x`` and ``y``.

    Both inputs are promoted to float64 and the result is
    ``1 - 2 * sum(x * y) / sum(x * x + y * y)``: 0 when the tensors are
    identical, growing as they diverge.

    Args:
        x: First tensor; must support ``.double()``.
        y: Second tensor, broadcast-compatible with ``x``.

    Returns:
        ``0.0`` when both inputs are entirely zero, otherwise a 0-dim
        float64 tensor holding the dissimilarity.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    # A zero denominator means every element of x and y is zero; the tensors
    # are identical, so report no difference instead of dividing 0 by 0
    # (which would yield NaN and fail any `diff < tol` check).
    if denominator == 0:
        return 0.0
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
|
||||
|
||||
|
||||
def get_model_config(tp_size: int):
|
||||
@@ -69,16 +78,11 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
|
||||
# --- Input Data ---
|
||||
# Use bf16/fp16 for input activation based on model config
|
||||
x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001
|
||||
x = torch.randn((batch_size, H), device="cuda", dtype=dtype)
|
||||
# --- Weights (Generate in higher precision, then convert to FP8) ---
|
||||
# Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
|
||||
w1_hp = (
|
||||
torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
|
||||
)
|
||||
w2_hp = (
|
||||
torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
|
||||
+ 0.00001
|
||||
)
|
||||
w1_hp = torch.randn((E, I, H), device="cuda", dtype=torch.float32)
|
||||
w2_hp = torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32)
|
||||
|
||||
w1 = to_fp8(w1_hp)
|
||||
w2 = to_fp8(w2_hp)
|
||||
@@ -149,13 +153,13 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
)
|
||||
|
||||
# Note: Triton expects non-transposed weights
|
||||
moe_config = MoeRunnerConfig(inplace=False)
|
||||
triton_lambda = lambda: fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
(topk_weights, topk_ids, "dummy"),
|
||||
inplace=False,
|
||||
activation="silu", # Assuming SiLU activation common in MoEs
|
||||
moe_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -221,32 +225,19 @@ def run_test(tp_size, batch_size, model_config, check=False):
|
||||
w1, # Original shape
|
||||
w2, # Original shape
|
||||
(topk_weights, topk_ids, "dummy"),
|
||||
inplace=False, # Important: Use False to get output tensor
|
||||
activation="silu",
|
||||
moe_config,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Ensure outputs are same dtype for comparison
|
||||
y_cutlass = y_cutlass.to(dtype)
|
||||
y_triton = y_triton.to(dtype)
|
||||
|
||||
abs_error = torch.abs(y_cutlass - y_triton)
|
||||
rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)
|
||||
|
||||
max_abs_err = abs_error.max().item()
|
||||
max_rel_err = rel_error.max().item()
|
||||
|
||||
print("y_cutlass:", y_cutlass[:, :10])
|
||||
print("y_triton:", y_triton[:, :10])
|
||||
print(f"Max absolute error: {max_abs_err:.6f}")
|
||||
print(f"Max relative error: {max_rel_err:.6f}")
|
||||
diff = calc_diff(y_cutlass, y_triton)
|
||||
print(f"Diff: {diff:.6f}")
|
||||
|
||||
# Tolerance might need adjustment based on FP8 specifics and kernel differences
|
||||
# FP8 comparisons often require higher tolerance than FP16/BF16
|
||||
assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
|
||||
assert diff < 1e-4, f"Diff too high! {diff}"
|
||||
print("Correctness check passed.")
|
||||
|
||||
|
||||
@@ -264,7 +255,21 @@ if __name__ == "__main__":
|
||||
"--batch-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # Adjusted default
|
||||
default=[
|
||||
1,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
2048,
|
||||
4096,
|
||||
8192,
|
||||
], # Adjusted default
|
||||
help="List of batch sizes to test",
|
||||
)
|
||||
parser.add_argument("--check", action="store_true", help="Enable check mode")
|
||||
|
||||
@@ -45,7 +45,7 @@ include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
repo-cutlass
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
||||
GIT_TAG 664c4f7b3ed1959414905025728eef5568209479
|
||||
GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
|
||||
@@ -457,39 +457,40 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace) {
|
||||
struct MmaConfig0 {
|
||||
struct MmaConfigSmallM {
|
||||
// Swap A/B
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _32, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
// TODO: Check Pingpong or Cooperative
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct MmaConfigH20LargeK {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct MmaConfig1 {
|
||||
struct MmaConfigHx00AndH20SmallK {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
// [NOTE] default for H20
|
||||
struct MmaConfigH20_default {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
|
||||
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
@@ -497,33 +498,34 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
|
||||
torch::Tensor output_t = output.t();
|
||||
torch::Tensor a_t = a.t();
|
||||
torch::Tensor b_t = b.transpose(1, 2);
|
||||
torch::Tensor scales_a_t = scales_a.t();
|
||||
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
|
||||
|
||||
const std::string H20_device_type_str = "NVIDIA H20";
|
||||
bool is_h20_device = isDeviceType(H20_device_type_str);
|
||||
const std::string H20_device_type_str("NVIDIA H20");
|
||||
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
|
||||
|
||||
if (is_h20_device) {
|
||||
using execute_gemm_config = MmaConfigH20_default;
|
||||
run_get_group_gemm_starts<
|
||||
execute_gemm_config::LayoutSFA,
|
||||
execute_gemm_config::LayoutSFB,
|
||||
execute_gemm_config::ScaleConfig>(
|
||||
if (a.size(0) <= 2048) {
|
||||
run_get_group_gemm_starts<MmaConfigSmallM::LayoutSFA, MmaConfigSmallM::LayoutSFB, MmaConfigSmallM::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
out_ptrs,
|
||||
a_scales_ptrs,
|
||||
b_scales_ptrs,
|
||||
a,
|
||||
b,
|
||||
output,
|
||||
scales_a,
|
||||
scales_b,
|
||||
b_t,
|
||||
a_t,
|
||||
output_t,
|
||||
scales_b_t,
|
||||
scales_a_t,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
|
||||
problem_sizes_transpose,
|
||||
true);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigSmallM, cutlass::layout::ColumnMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -534,13 +536,17 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
stride_c,
|
||||
layout_sfa,
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose,
|
||||
expert_offsets,
|
||||
workspace);
|
||||
output = output_t.t();
|
||||
} else {
|
||||
if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
|
||||
if (is_h20_device && a.size(1) > 128) {
|
||||
// For H20 with K > 128, use Pingpong Schedule
|
||||
run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
|
||||
run_get_group_gemm_starts<
|
||||
MmaConfigH20LargeK::LayoutSFA,
|
||||
MmaConfigH20LargeK::LayoutSFB,
|
||||
MmaConfigH20LargeK::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -556,7 +562,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigH20LargeK, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -572,7 +578,10 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
workspace);
|
||||
} else {
|
||||
// For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
|
||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||
run_get_group_gemm_starts<
|
||||
MmaConfigHx00AndH20SmallK::LayoutSFA,
|
||||
MmaConfigHx00AndH20SmallK::LayoutSFB,
|
||||
MmaConfigHx00AndH20SmallK::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
@@ -588,7 +597,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
|
||||
layout_sfb,
|
||||
problem_sizes,
|
||||
problem_sizes_transpose);
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
|
||||
launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigHx00AndH20SmallK, cutlass::layout::RowMajor>(
|
||||
out_ptrs,
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
|
||||
@@ -5,10 +5,6 @@ import pytest
|
||||
import torch
|
||||
from sgl_kernel import fp8_blockwise_scaled_grouped_mm
|
||||
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
per_token_group_quant_fp8_hopper_moe_mn_major,
|
||||
)
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Return ``a`` divided by ``b``, rounded up (ceiling division)."""
    quotient, remainder = divmod(a, b)
    # Floor division already rounds down; bump by one when anything is left over.
    if remainder:
        quotient += 1
    return quotient
|
||||
@@ -106,24 +102,19 @@ def is_sm90_supported(device=None) -> bool:
|
||||
not (is_sm100_supported() or is_sm90_supported()),
|
||||
reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
|
||||
)
|
||||
@pytest.mark.parametrize("num_experts", [8, 16])
|
||||
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
|
||||
@pytest.mark.parametrize("use_custom_kernel", [True, False])
|
||||
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel):
|
||||
cc = torch.cuda.get_device_capability(None)[0]
|
||||
if cc == 10 and use_custom_kernel:
|
||||
return
|
||||
def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
device = "cuda"
|
||||
alignment = 16
|
||||
n_g = alignment * random.randint(1, 5) * 128
|
||||
k_g = alignment * random.randint(1, 5) * 128
|
||||
alignment = 128
|
||||
n_g = random.randint(1, 64) * 128
|
||||
k_g = random.randint(1, 64) * 128
|
||||
|
||||
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
|
||||
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
|
||||
layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
||||
layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
|
||||
|
||||
a_original_tensors = []
|
||||
a_tensors = []
|
||||
b_tensors = []
|
||||
a_scales_tensors = []
|
||||
@@ -131,7 +122,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
|
||||
baseline_tensors = []
|
||||
|
||||
for g in range(num_experts):
|
||||
m_g = alignment * random.randint(1, 64)
|
||||
m_g = random.randint(1, 256)
|
||||
expert_offsets[g + 1] = expert_offsets[g] + m_g
|
||||
problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)
|
||||
|
||||
@@ -144,7 +135,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
|
||||
b_g, b_scale = per_block_cast_to_fp8(
|
||||
b
|
||||
) # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
|
||||
a_original_tensors.append(a)
|
||||
a_tensors.append(a_g)
|
||||
b_tensors.append(b_g)
|
||||
a_scales_tensors.append(a_scale)
|
||||
@@ -152,9 +142,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
|
||||
|
||||
baseline = torch.mm(a, b)
|
||||
baseline_tensors.append(baseline)
|
||||
a_original_stack = torch.empty(
|
||||
(expert_offsets[-1], k_g), device=device, dtype=out_dtype
|
||||
)
|
||||
a_stack = torch.empty(
|
||||
(expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
@@ -162,52 +149,28 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
|
||||
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
a_scale_stack = torch.empty(
|
||||
(expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
|
||||
(expert_offsets[-1], (k_g // 128)), device=device, dtype=torch.float32
|
||||
)
|
||||
b_scale_stack = torch.empty(
|
||||
(num_experts, k_g // 128, n_g // 128), device=device, dtype=torch.float32
|
||||
(num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
# Matrix A is Row-Major.
|
||||
a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = (
|
||||
a_original_tensors[g]
|
||||
)
|
||||
a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
|
||||
a_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_tensors[
|
||||
g
|
||||
] # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
|
||||
] # a_stack[expert_offsets[g] : expert_offsets[g + 1], :] -- (M, K):(K, 1)
|
||||
b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1)
|
||||
if cc == 9:
|
||||
# For SM90, we need MN-Major scale factor
|
||||
# a_scales_tensors[g] -- (M, k):(k, 1)
|
||||
# a_scales_tensors[g].t().contiguous() -- (k, M):(M, 1)
|
||||
a_scale_stack[
|
||||
expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
|
||||
] = (a_scales_tensors[g].t().contiguous().view(-1))
|
||||
b_scale_stack[g] = b_scales_tensors[g] # b_scale_stack[g] -- (k, n):(n, 1)
|
||||
elif cc == 10:
|
||||
# For SM100, we need K-Major scale factor
|
||||
# a_scales_tensors[g] -- (M, k):(k, 1)
|
||||
a_scale_stack[
|
||||
expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
|
||||
] = a_scales_tensors[g].view(-1)
|
||||
b_scale_stack[g] = b_scales_tensors[
|
||||
g
|
||||
] # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
||||
a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
|
||||
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
||||
if cc == 10:
|
||||
b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()
|
||||
|
||||
if use_custom_kernel:
|
||||
# Replace a_stack, a_scale_stack with custom kernel output
|
||||
a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major(
|
||||
a_original_stack,
|
||||
expert_offsets[:-1],
|
||||
problem_sizes,
|
||||
128,
|
||||
expert_tokens_alignment=alignment,
|
||||
)
|
||||
# We need K-Major scale factor
|
||||
a_scale_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_scales_tensors[
|
||||
g
|
||||
]
|
||||
b_scale_stack[g] = b_scales_tensors[
|
||||
g
|
||||
].t() # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
||||
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
||||
b_scale_stack = b_scale_stack.transpose(1, 2)
|
||||
|
||||
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
||||
a_strides = torch.full(
|
||||
@@ -250,7 +213,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
|
||||
diff = calc_diff(actual, baseline)
|
||||
assert diff < 0.001
|
||||
print(
|
||||
f"cc={cc}0 num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
|
||||
f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user