enable aiter_biased_grouped_topk kernel (#7423)
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -30,6 +30,7 @@ from sglang.srt.managers.expert_location_dispatch import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
+    get_bool_env_var,
     get_compiler_backend,
     is_cpu,
     is_cuda,
@@ -38,6 +39,7 @@ from sglang.srt.utils import (
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
@@ -46,6 +48,11 @@ if _is_cuda:
 
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
 
 def fused_topk_torch_native(
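Note on the gating above: the aiter path is strictly opt-in. It is reachable only on ROCm builds (`_use_aiter` requires `_is_hip`) and only when the user sets `SGLANG_USE_AITER`, and a missing aiter install raises immediately instead of silently falling back to the slower path. A minimal standalone sketch of the pattern, assuming only that `get_bool_env_var` treats "1"/"true" as truthy (the `_sketch` names below are hypothetical, not sglang's actual utilities):

import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Assumed semantics: "1" or "true" (any case) enables the flag.
    return os.getenv(name, default).strip().lower() in ("1", "true")

_is_hip_sketch = False  # stand-in for real ROCm platform detection

_use_aiter = get_bool_env_var_sketch("SGLANG_USE_AITER") and _is_hip_sketch

if _use_aiter:
    try:
        from aiter import biased_grouped_topk  # noqa: F401
    except ImportError:
        # Fail loudly: the user explicitly asked for aiter, so a silent
        # fallback would hide a misconfigured environment.
        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")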
@@ -347,6 +354,25 @@ def biased_grouped_topk_gpu(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
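For context on what the new branch computes: `aiter_biased_grouped_topk` fills the preallocated `topk_weights` / `topk_ids` buffers in place, replacing the torch-compiled fallback in the `else:` branch. Below is a torch-native sketch of DeepSeek-V3-style biased grouped top-k routing, modeled on sglang's reference implementation; the exact aiter semantics (sigmoid scoring, top-2 group scoring, in-kernel scaling) are assumptions for illustration, not the kernel's definitive contract:

import torch

def biased_grouped_topk_reference(
    gating_output: torch.Tensor,    # [num_tokens, num_experts] router logits
    correction_bias: torch.Tensor,  # [num_experts] expert-choice bias
    topk: int,
    renormalize: bool,
    num_expert_group: int,
    topk_group: int,
    routed_scaling_factor: float,
):
    scores = gating_output.sigmoid()
    num_token = scores.shape[0]
    # The bias affects which experts are chosen, not the returned weights.
    scores_for_choice = scores + correction_bias.unsqueeze(0)
    # Score each expert group by the sum of its top-2 experts, keep the
    # best `topk_group` groups, and mask out experts in the other groups.
    group_scores = (
        scores_for_choice.view(num_token, num_expert_group, -1)
        .topk(2, dim=-1)[0]
        .sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )
    masked = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))
    _, topk_ids = torch.topk(masked, k=topk, dim=-1, sorted=False)
    topk_weights = scores.gather(1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # Assumed fused step: the aiter kernel applies routed_scaling_factor
    # itself, which is what lets DeepseekV2MoE skip its own multiply.
    topk_weights = topk_weights * routed_scaling_factor
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

The final multiply is the piece the deepseek_v2 hunk below relies on: because the kernel folds `routed_scaling_factor` into the returned weights, the model must not scale the expert output again.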
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -421,7 +421,7 @@ class CudaGraphRunner:
                 empty_cache=False,
             )
             capture_range.set_description(
-                f"Capturing batches ({avail_mem=:.2f} GB)"
+                f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
             )
 
             with patch_model(
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -388,7 +388,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
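The condition change above prevents double scaling: on the CUDA fused-gate path and now on the aiter path, `routed_scaling_factor` is already folded into the routing weights (per the new comment), so scaling `final_hidden_states` again would apply the factor twice. A small sketch of the invariant, with hypothetical names:

def combine_expert_output(expert_out, shared_output, routed_scaling_factor,
                          weights_prescaled: bool):
    # weights_prescaled mirrors `_is_cuda or _use_aiter`: on those paths the
    # top-k kernel already multiplied the routing weights by the factor, so
    # multiplying the combined output here would double-apply it.
    out = expert_out if weights_prescaled else expert_out * routed_scaling_factor
    if shared_output is not None:
        out = out + shared_output
    return out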