[BugFix] fix pre_reorder_triton_kernel default int32 issue (#7814)
This commit is contained in:
@@ -236,7 +236,8 @@ def pre_reorder_triton_kernel(
|
|||||||
):
|
):
|
||||||
OutDtype = gateup_input_ptr.dtype.element_ty
|
OutDtype = gateup_input_ptr.dtype.element_ty
|
||||||
|
|
||||||
src_idx = tl.program_id(0)
|
src_idx_int32 = tl.program_id(0)
|
||||||
|
src_idx = src_idx_int32.to(tl.int64)
|
||||||
src2dst_ptr = src2dst_ptr + src_idx * topk
|
src2dst_ptr = src2dst_ptr + src_idx * topk
|
||||||
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
topk_ids_ptr = topk_ids_ptr + src_idx * topk
|
||||||
src_ptr = input_ptr + src_idx * hidden_size
|
src_ptr = input_ptr + src_idx * hidden_size
|
||||||
@@ -255,7 +256,8 @@ def pre_reorder_triton_kernel(
|
|||||||
else:
|
else:
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
|
|
||||||
dst_idx = tl.load(src2dst_ptr + idx)
|
dst_idx_int32 = tl.load(src2dst_ptr + idx)
|
||||||
|
dst_idx = dst_idx_int32.to(tl.int64)
|
||||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
||||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||||
offset = start_offset + vec
|
offset = start_offset + vec
|
||||||
|
|||||||
Reference in New Issue
Block a user