feat: support moe_align_block_size_triton (#2712)
Co-authored-by: WANDY666 <1060304770@qq.com>
This commit is contained in:
@@ -17,15 +17,21 @@ from sglang.srt.layers.moe.topk import select_experts
|
|||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||||
from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
|
from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
|
||||||
|
|
||||||
not_hip = False
|
is_hip_flag = False
|
||||||
if not is_hip():
|
if not is_hip():
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
|
|
||||||
not_hip = True
|
is_hip_flag = False
|
||||||
|
else:
|
||||||
|
is_hip_flag = True
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||||
|
|
||||||
|
enable_moe_align_block_size_triton = bool(
|
||||||
|
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def fused_moe_kernel(
|
def fused_moe_kernel(
|
||||||
@@ -222,6 +228,139 @@ def fused_moe_kernel(
|
|||||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def ceil_div(a, b):
|
||||||
|
return (a + b - 1) // b
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_align_block_size_stage1(
|
||||||
|
topk_ids_ptr,
|
||||||
|
tokens_cnts_ptr,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
numel: tl.constexpr,
|
||||||
|
tokens_per_thread: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
|
||||||
|
start_idx = pid * tokens_per_thread
|
||||||
|
|
||||||
|
off_c = (pid + 1) * num_experts
|
||||||
|
|
||||||
|
for i in range(tokens_per_thread):
|
||||||
|
if start_idx + i < numel:
|
||||||
|
idx = tl.load(topk_ids_ptr + start_idx + i)
|
||||||
|
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
|
||||||
|
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_align_block_size_stage2(
|
||||||
|
tokens_cnts_ptr,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
|
||||||
|
last_cnt = 0
|
||||||
|
for i in range(1, num_experts + 1):
|
||||||
|
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
|
||||||
|
last_cnt = last_cnt + token_cnt
|
||||||
|
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_align_block_size_stage3(
|
||||||
|
total_tokens_post_pad_ptr,
|
||||||
|
tokens_cnts_ptr,
|
||||||
|
cumsum_ptr,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
):
|
||||||
|
last_cumsum = 0
|
||||||
|
off_cnt = num_experts * num_experts
|
||||||
|
for i in range(1, num_experts + 1):
|
||||||
|
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
|
||||||
|
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
|
||||||
|
tl.store(cumsum_ptr + i, last_cumsum)
|
||||||
|
tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def moe_align_block_size_stage4(
|
||||||
|
topk_ids_ptr,
|
||||||
|
sorted_token_ids_ptr,
|
||||||
|
expert_ids_ptr,
|
||||||
|
tokens_cnts_ptr,
|
||||||
|
cumsum_ptr,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
block_size: tl.constexpr,
|
||||||
|
numel: tl.constexpr,
|
||||||
|
tokens_per_thread: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
start_idx = tl.load(cumsum_ptr + pid)
|
||||||
|
end_idx = tl.load(cumsum_ptr + pid + 1)
|
||||||
|
|
||||||
|
for i in range(start_idx, end_idx, block_size):
|
||||||
|
tl.store(expert_ids_ptr + i // block_size, pid)
|
||||||
|
|
||||||
|
start_idx = pid * tokens_per_thread
|
||||||
|
off_t = pid * num_experts
|
||||||
|
|
||||||
|
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
|
||||||
|
expert_id = tl.load(topk_ids_ptr + i)
|
||||||
|
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
|
||||||
|
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
|
||||||
|
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
|
||||||
|
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def moe_align_block_size_triton(
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
block_size: int,
|
||||||
|
sorted_token_ids: torch.Tensor,
|
||||||
|
expert_ids: torch.Tensor,
|
||||||
|
num_tokens_post_pad: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
numel = topk_ids.numel()
|
||||||
|
grid = (num_experts,)
|
||||||
|
tokens_cnts = torch.zeros(
|
||||||
|
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
|
||||||
|
tokens_per_thread = ceil_div(numel, num_experts)
|
||||||
|
|
||||||
|
moe_align_block_size_stage1[grid](
|
||||||
|
topk_ids,
|
||||||
|
tokens_cnts,
|
||||||
|
num_experts,
|
||||||
|
numel,
|
||||||
|
tokens_per_thread,
|
||||||
|
)
|
||||||
|
moe_align_block_size_stage2[grid](
|
||||||
|
tokens_cnts,
|
||||||
|
num_experts,
|
||||||
|
)
|
||||||
|
moe_align_block_size_stage3[(1,)](
|
||||||
|
num_tokens_post_pad,
|
||||||
|
tokens_cnts,
|
||||||
|
cumsum,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
)
|
||||||
|
moe_align_block_size_stage4[grid](
|
||||||
|
topk_ids,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
tokens_cnts,
|
||||||
|
cumsum,
|
||||||
|
num_experts,
|
||||||
|
block_size,
|
||||||
|
numel,
|
||||||
|
tokens_per_thread,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size(
|
def moe_align_block_size(
|
||||||
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
@@ -272,24 +411,36 @@ def moe_align_block_size(
|
|||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||||
if not_hip and num_experts >= 224:
|
if num_experts >= 224:
|
||||||
token_cnts_buffer = torch.empty(
|
if enable_moe_align_block_size_triton or is_hip_flag:
|
||||||
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
|
moe_align_block_size_triton(
|
||||||
)
|
topk_ids,
|
||||||
cumsum_buffer = torch.empty(
|
num_experts,
|
||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
block_size,
|
||||||
)
|
sorted_ids,
|
||||||
|
expert_ids,
|
||||||
|
num_tokens_post_pad,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
token_cnts_buffer = torch.empty(
|
||||||
|
(num_experts + 1) * num_experts,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=topk_ids.device,
|
||||||
|
)
|
||||||
|
cumsum_buffer = torch.empty(
|
||||||
|
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
|
||||||
sgl_moe_align_block_size(
|
sgl_moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
block_size,
|
block_size,
|
||||||
sorted_ids,
|
sorted_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
token_cnts_buffer,
|
token_cnts_buffer,
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ops.moe_align_block_size(
|
ops.moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -854,17 +1005,18 @@ def fused_experts_impl(
|
|||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not_hip:
|
if is_hip_flag:
|
||||||
|
ops.moe_sum(
|
||||||
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
|
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
|
)
|
||||||
|
else:
|
||||||
torch.sum(
|
torch.sum(
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||||
dim=1,
|
dim=1,
|
||||||
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
ops.moe_sum(
|
|
||||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
|
||||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
|
||||||
)
|
|
||||||
return out_hidden_states
|
return out_hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user