update sgl-kernel for EP: python part (#8550)
This commit is contained in:
@@ -54,7 +54,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.2.7",
+    "sgl-kernel==0.2.8",
     "torch==2.7.1",
     "torchaudio==2.7.1",
     "torchvision==0.22.1",
@@ -648,7 +648,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.7",
+            "0.2.8",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
@@ -568,7 +568,7 @@ def moe_align_block_size(
     - The padding ensures that the total number of tokens is now divisible
       by block_size for proper block matrix operations.
     """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
    sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
@@ -578,13 +578,9 @@ def moe_align_block_size(
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

+    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
     cumsum_buffer = torch.empty(
-        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-    )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
+        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
     )

     # Threshold based on benchmark results
@@ -594,12 +590,11 @@ def moe_align_block_size(

     sgl_moe_align_block_size(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids,
         expert_ids,
         num_tokens_post_pad,
-        token_cnts_buffer,
         cumsum_buffer,
         fuse_sorted_ids_padding,
     )
Reference in New Issue
Block a user