Cleanup MoE Refactor (#9223)
This commit is contained in:
@@ -573,6 +573,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
topk_output: TopKOutput,
|
||||
moe_runner_config: MoeRunnerConfig,
|
||||
) -> torch.Tensor:
|
||||
|
||||
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||
|
||||
if self.use_flashinfer:
|
||||
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
||||
x_quant, x_scale = mxfp8_quantize(
|
||||
@@ -580,8 +583,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
) # to mxfp8
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||
assert x_quant.shape[-1] == self.hidden_size
|
||||
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||
|
||||
top_k, router_logits = topk_output
|
||||
top_k = topk_output.topk_config.top_k
|
||||
router_logits = topk_output.router_logits
|
||||
|
||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||
router_logits.to(torch.bfloat16),
|
||||
@@ -602,8 +607,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
None, # output2_scale_scalar
|
||||
layer.num_experts,
|
||||
top_k,
|
||||
None, # n_group
|
||||
None, # topk_group
|
||||
None, # n_group # TODO: support n_group
|
||||
None, # topk_group # TODO: support topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
||||
layer.num_local_experts, # local num experts
|
||||
|
||||
@@ -459,15 +459,15 @@ class DeepseekV2MoE(nn.Module):
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||
if not _is_cuda:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
current_stream.wait_stream(self.alt_stream)
|
||||
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
||||
final_hidden_states_out = torch.empty_like(final_hidden_states)
|
||||
|
||||
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||
final_hidden_states = final_hidden_states_out
|
||||
sm.tag(final_hidden_states)
|
||||
@@ -489,10 +489,9 @@ class DeepseekV2MoE(nn.Module):
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||
if not _is_cuda and not _use_aiter:
|
||||
# fused in biased_grouped_topk so we can skip here
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
@@ -509,9 +509,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||
if not _is_cuda:
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
current_stream.wait_stream(self.alt_stream)
|
||||
@@ -552,9 +551,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
kwargs = {"hidden_states": hidden_states}
|
||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||
if not _is_cuda and not _use_aiter:
|
||||
# fused in biased_grouped_topk so we can skip here
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
|
||||
Reference in New Issue
Block a user