Fuse renormalization into the MoE top-k softmax kernel (Python-side changes) (#7751)
Co-authored-by: ispobock <ispobaoke@gmail.com> Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -52,7 +52,7 @@ runtime_common = [
|
||||
|
||||
srt = [
|
||||
"sglang[runtime_common]",
|
||||
"sgl-kernel==0.2.1",
|
||||
"sgl-kernel==0.2.2",
|
||||
"torch==2.7.1",
|
||||
"torchaudio==2.7.1",
|
||||
"torchvision==0.22.1",
|
||||
|
||||
@@ -650,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
if _is_cuda:
|
||||
assert_pkg_version(
|
||||
"sgl-kernel",
|
||||
"0.2.1",
|
||||
"0.2.2",
|
||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||
)
|
||||
|
||||
|
||||
@@ -108,38 +108,14 @@ def fused_topk(
|
||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||
)
|
||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||
token_expert_indicies = torch.empty(
|
||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||
)
|
||||
|
||||
topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
token_expert_indicies,
|
||||
gating_output.float(),
|
||||
)
|
||||
del token_expert_indicies
|
||||
|
||||
return _fused_topk_postprocess(
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
renormalize=renormalize,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
renormalize,
|
||||
)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def _fused_topk_postprocess(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
renormalize,
|
||||
expert_location_dispatch_info,
|
||||
num_token_non_padded,
|
||||
):
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user