[1/N][Bug] Fix w4afp8 MoE NaN issue (sgl-kernel, fixed) (#10108)
@@ -41,8 +41,8 @@ using MmaType = cutlass::float_e4m3_t; // FP8 e4m3 type
 using QuantType = cutlass::int4b_t;          // 4-bit integer type
 using ElementAccumulator = float;            // Accumulator type
 using ElementScale = cutlass::bfloat16_t;    // Scale type
-using ElementC = cutlass::half_t;            // Default output type (FP16)
-using ElementD = ElementC;                   // Default output type (FP16)
+using ElementC = cutlass::bfloat16_t;        // Output type
+using ElementD = ElementC;                   // Output type
 using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;

 // Architecture-specific configurations

@@ -96,7 +96,7 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     a_q = torch.clamp((a / a_scale), -448.0, 448.0).to(torch.float8_e4m3fn).to(device)

     # Create output tensor
-    c = torch.empty((m, n), dtype=torch.float16, device=device)
+    c = torch.empty((m, n), dtype=torch.bfloat16, device=device)
     cutlass_w4a8_moe_mm(
         c,
         a_q,
@@ -211,7 +211,7 @@ def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     b_strides = a_strides
     s_strides = c_strides

-    c_perm = torch.empty((batch_size, n), dtype=torch.float16, device=device)
+    c_perm = torch.empty((batch_size, n), dtype=torch.bfloat16, device=device)
     cutlass_w4a8_moe_mm(
         c_perm,
         a_q_perm,
@@ -262,10 +262,9 @@ def ref_grouped_gemm(c, a, a_scale, w, w_scale, num_experts, experts_selection_r
             continue
         a = a_q[token_idx]

-        ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(float)
-        ref_w = (w[i].to(float) * ref_w_scale_repeat).to(dtype)
-        c = torch.matmul(a.to(dtype), ref_w.t().to(dtype)) * a_scale
-        c = c.to(dtype)
-        c_ref[token_idx] = c
+        ref_w_scale_repeat = w_scale[i].repeat_interleave(128, dim=1).to(torch.float32)
+        ref_w = w[i].to(torch.float32) * ref_w_scale_repeat
+        c = torch.matmul(a.to(torch.float32), ref_w.t()) * a_scale
+        c_ref[token_idx] = c.to(dtype)

     return c_ref

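Below is a minimal, self-contained sketch (not part of the patch) of the reference dequantization path the updated tests rely on: per-group weight scales are expanded along K, the matmul accumulates in float32, and the result is cast once to the (now bfloat16) output dtype. The function name, tensor shapes, and stand-in random tensors are illustrative assumptions; only the group size of 128 and the fp32-accumulate/cast-last pattern come from the patch.

# Hypothetical standalone sketch; names and shapes are illustrative only.
import torch

def ref_dequant_gemm(a_q, a_scale, w_q, w_scale, dtype=torch.bfloat16, group_size=128):
    # One weight scale per group of `group_size` columns along K.
    w_scale_full = w_scale.repeat_interleave(group_size, dim=1).to(torch.float32)
    w = w_q.to(torch.float32) * w_scale_full                    # dequantized weights, fp32
    c = torch.matmul(a_q.to(torch.float32), w.t()) * a_scale    # accumulate in fp32
    return c.to(dtype)                                          # cast to output dtype at the end

m, k, n, g = 4, 256, 8, 128
a_q = torch.randn(m, k)                     # stand-in for the fp8 activations
w_q = torch.randint(-8, 8, (n, k)).float()  # stand-in for the int4 weights
w_scale = torch.rand(n, k // g)
out = ref_dequant_gemm(a_q, 2.0, w_q, w_scale)
print(out.dtype, out.shape)  # torch.bfloat16, (4, 8)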