fuse renormal into moe topk softmax kernel python code (#7751)
Co-authored-by: ispobock <ispobaoke@gmail.com> Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -52,7 +52,7 @@ runtime_common = [
|
|||||||
|
|
||||||
srt = [
|
srt = [
|
||||||
"sglang[runtime_common]",
|
"sglang[runtime_common]",
|
||||||
"sgl-kernel==0.2.1",
|
"sgl-kernel==0.2.2",
|
||||||
"torch==2.7.1",
|
"torch==2.7.1",
|
||||||
"torchaudio==2.7.1",
|
"torchaudio==2.7.1",
|
||||||
"torchvision==0.22.1",
|
"torchvision==0.22.1",
|
||||||
|
|||||||
@@ -650,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"sgl-kernel",
|
"sgl-kernel",
|
||||||
"0.2.1",
|
"0.2.2",
|
||||||
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
"Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -108,38 +108,14 @@ def fused_topk(
|
|||||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||||
)
|
)
|
||||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||||
token_expert_indicies = torch.empty(
|
|
||||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
|
||||||
)
|
|
||||||
|
|
||||||
topk_softmax(
|
topk_softmax(
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
token_expert_indicies,
|
|
||||||
gating_output.float(),
|
gating_output.float(),
|
||||||
)
|
renormalize,
|
||||||
del token_expert_indicies
|
|
||||||
|
|
||||||
return _fused_topk_postprocess(
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
renormalize=renormalize,
|
|
||||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
|
||||||
num_token_non_padded=num_token_non_padded,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
|
||||||
def _fused_topk_postprocess(
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
renormalize,
|
|
||||||
expert_location_dispatch_info,
|
|
||||||
num_token_non_padded,
|
|
||||||
):
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
|
||||||
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
|
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user