Let ep_scatter support arbitrary strides / ue8m0 format (#7309)
This commit is contained in:
@@ -813,14 +813,17 @@ def _fwd_kernel_ep_scatter_2(
     offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
     mask = offset_in < HIDDEN_SIZE
 
-    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
-    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = index_in_s < SCALE_HIDDEN_SIZE
 
     for token_id_int32 in range(start_token_id, total_token_num, grid_num):
         token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
-            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+            recv_x_scale
+            + token_id * recv_x_scale_stride0
+            + index_in_s * recv_x_scale_stride1,
+            mask=mask_s,
         )
 
         for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
@@ -841,7 +844,11 @@ def _fwd_kernel_ep_scatter_2(
                     output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                 )
                 tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
-                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+                tl.store(
+                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
+                    to_copy_s,
+                    mask=mask_s,
+                )
 
 
 # copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@@ -856,6 +863,7 @@ def ep_scatter(
     output_tensor_scale: torch.Tensor,
     m_indices: torch.Tensor,
     output_index: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     BLOCK_E = 128  # token num of per expert is aligned to 128
     BLOCK_D = 128  # block size of quantization
@@ -865,7 +873,15 @@ def ep_scatter(
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts
 
+    scale_hidden_size = hidden_size // BLOCK_D
+    if scale_ue8m0:
+        # ue8m0 scales are packed here (4 scales per int32),
+        # hence the effective size of this dimension is divided by 4.
+        scale_hidden_size = ceil_div(scale_hidden_size, 4)
+
     assert m_indices.shape[0] % BLOCK_E == 0
+    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
@@ -904,8 +920,8 @@ def ep_scatter(
         num_warps=num_warps,
         HIDDEN_SIZE=hidden_size,
         HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
-        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
-        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+        SCALE_HIDDEN_SIZE=scale_hidden_size,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
     )
     return
||||
Reference in New Issue
Block a user