diff --git a/python/pyproject.toml b/python/pyproject.toml index 13b3d9952..7a0f23e4f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -52,7 +52,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", - "sgl-kernel==0.2.1", + "sgl-kernel==0.2.2", "torch==2.7.1", "torchaudio==2.7.1", "torchvision==0.22.1", diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index f56c40133..5bb67d1e0 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -650,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs): if _is_cuda: assert_pkg_version( "sgl-kernel", - "0.2.1", + "0.2.2", "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`", ) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 908927b88..3c6d9eae9 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -108,38 +108,14 @@ def fused_topk( M, topk, dtype=torch.float32, device=hidden_states.device ) topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) topk_softmax( topk_weights, topk_ids, - token_expert_indicies, gating_output.float(), - ) - del token_expert_indicies - - return _fused_topk_postprocess( - topk_weights=topk_weights, - topk_ids=topk_ids, - renormalize=renormalize, - expert_location_dispatch_info=expert_location_dispatch_info, - num_token_non_padded=num_token_non_padded, + renormalize, ) - -@torch.compile(dynamic=True, backend=get_compiler_backend()) -def _fused_topk_postprocess( - topk_weights, - topk_ids, - renormalize, - expert_location_dispatch_info, - num_token_non_padded, -): - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) return topk_weights, topk_ids