Fix one bug in the grouped-gemm triton kernel (#6772)

This commit is contained in:
Cheng Wan
2025-05-30 01:42:08 -07:00
committed by GitHub
parent 69dd878b51
commit b581b22504

View File

@@ -621,7 +621,7 @@ def grouped_gemm_triton_kernel(
     b_ptr += BLOCK_SIZE_K
     if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-        scale_a_value = tl.load(scale_a + expert_id)
+        scale_a_value = tl.load(scale_a + m_range_start + offs_am[:, None])
         scale_b_value = tl.load(scale_b + expert_id)
         accumulator *= scale_a_value * scale_b_value