Fix one bug in the grouped-gemm triton kernel (#6772)
@@ -621,7 +621,7 @@ def grouped_gemm_triton_kernel(
         b_ptr += BLOCK_SIZE_K

     if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
-        scale_a_value = tl.load(scale_a + expert_id)
+        scale_a_value = tl.load(scale_a + m_range_start + offs_am[:, None])
         scale_b_value = tl.load(scale_b + expert_id)
         accumulator *= scale_a_value * scale_b_value

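The change switches the activation scale lookup from one value per expert to one value per output row of the expert's block (scale_a + m_range_start + offs_am), while the weight scale stays per expert. This suggests per-row (per-token) activation scales in the fp8 w8a8 path. Below is a minimal NumPy sketch of that dequantization step, assuming per-row activation scales and a scalar per-expert weight scale; the function name and shapes are illustrative, not taken from the kernel.

import numpy as np

def dequantize_expert_tile(accumulator, scale_a_rows, scale_b_expert):
    # accumulator: (BLOCK_M, BLOCK_N) raw matmul partial sums for one expert's tile
    # scale_a_rows: (BLOCK_M,) one activation scale per output row (token)
    # scale_b_expert: scalar weight scale shared by the whole expert
    # Each row gets its own activation scale; a single per-expert value here
    # would be the bug the commit fixes.
    return accumulator * scale_a_rows[:, None] * scale_b_expert

accumulator = np.random.randn(4, 8).astype(np.float32)
scale_a_rows = np.array([0.5, 1.0, 2.0, 0.25], dtype=np.float32)
out = dequantize_expert_tile(accumulator, scale_a_rows, scale_b_expert=0.1)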