Fix one bug in the grouped-gemm triton kernel (#6772)
This commit is contained in:
@@ -621,7 +621,7 @@ def grouped_gemm_triton_kernel(
|
|||||||
b_ptr += BLOCK_SIZE_K
|
b_ptr += BLOCK_SIZE_K
|
||||||
|
|
||||||
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
|
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
|
||||||
scale_a_value = tl.load(scale_a + expert_id)
|
scale_a_value = tl.load(scale_a + m_range_start + offs_am[:, None])
|
||||||
scale_b_value = tl.load(scale_b + expert_id)
|
scale_b_value = tl.load(scale_b + expert_id)
|
||||||
accumulator *= scale_a_value * scale_b_value
|
accumulator *= scale_a_value * scale_b_value
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user