[MoE] Enable renormalize=False in Triton kernels (#8735)
This commit is contained in:
@@ -183,15 +183,13 @@ class TopK(CustomOp):
|
|||||||
*,
|
*,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
sm_first: bool = False, # only used for triton kernels topk
|
|
||||||
) -> TopKOutput:
|
) -> TopKOutput:
|
||||||
if self.use_triton_kernels:
|
if self.use_triton_kernels:
|
||||||
return triton_kernels_topk(
|
# renormalize=True is equivalent to sm_first=False
|
||||||
router_logits=router_logits,
|
routing_data, gather_idx, scatter_idx = routing(
|
||||||
topk=self.top_k,
|
router_logits, self.top_k, sm_first=not self.renormalize
|
||||||
renormalize=self.renormalize,
|
|
||||||
sm_first=sm_first,
|
|
||||||
)
|
)
|
||||||
|
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
||||||
else:
|
else:
|
||||||
torch_native = False
|
torch_native = False
|
||||||
return select_experts(
|
return select_experts(
|
||||||
@@ -647,22 +645,6 @@ def biased_grouped_topk_cpu(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def triton_kernels_topk(
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool = False,
|
|
||||||
sm_first: bool = False,
|
|
||||||
) -> TritonKernelTopKOutput:
|
|
||||||
"""Top-K routing for Triton kernels MoE."""
|
|
||||||
assert not renormalize, "Triton kernels topk doesn't support renormalize"
|
|
||||||
routing_data, gather_idx, scatter_idx = routing(
|
|
||||||
logits=router_logits,
|
|
||||||
n_expts_act=topk,
|
|
||||||
sm_first=sm_first,
|
|
||||||
)
|
|
||||||
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
|
|
||||||
|
|
||||||
|
|
||||||
if _is_cpu and _is_cpu_amx_available:
|
if _is_cpu and _is_cpu_amx_available:
|
||||||
biased_grouped_topk = biased_grouped_topk_cpu
|
biased_grouped_topk = biased_grouped_topk_cpu
|
||||||
grouped_topk = grouped_topk_cpu
|
grouped_topk = grouped_topk_cpu
|
||||||
|
|||||||
Reference in New Issue
Block a user