Cleanup MoE Refactor (#9223)
This commit is contained in:
@@ -573,6 +573,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_output: TopKOutput,
|
topk_output: TopKOutput,
|
||||||
moe_runner_config: MoeRunnerConfig,
|
moe_runner_config: MoeRunnerConfig,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.topk import TopKOutputChecker
|
||||||
|
|
||||||
if self.use_flashinfer:
|
if self.use_flashinfer:
|
||||||
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
|
||||||
x_quant, x_scale = mxfp8_quantize(
|
x_quant, x_scale = mxfp8_quantize(
|
||||||
@@ -580,8 +583,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
) # to mxfp8
|
) # to mxfp8
|
||||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||||
assert x_quant.shape[-1] == self.hidden_size
|
assert x_quant.shape[-1] == self.hidden_size
|
||||||
|
assert TopKOutputChecker.format_is_bypassed(topk_output)
|
||||||
|
|
||||||
top_k, router_logits = topk_output
|
top_k = topk_output.topk_config.top_k
|
||||||
|
router_logits = topk_output.router_logits
|
||||||
|
|
||||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||||
router_logits.to(torch.bfloat16),
|
router_logits.to(torch.bfloat16),
|
||||||
@@ -602,8 +607,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
None, # output2_scale_scalar
|
None, # output2_scale_scalar
|
||||||
layer.num_experts,
|
layer.num_experts,
|
||||||
top_k,
|
top_k,
|
||||||
None, # n_group
|
None, # n_group # TODO: support n_group
|
||||||
None, # topk_group
|
None, # topk_group # TODO: support topk_group
|
||||||
self.intermediate_size, # padded to multiple of 256
|
self.intermediate_size, # padded to multiple of 256
|
||||||
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
|
||||||
layer.num_local_experts, # local num experts
|
layer.num_local_experts, # local num experts
|
||||||
|
|||||||
@@ -459,15 +459,15 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
with torch.cuda.stream(self.alt_stream):
|
with torch.cuda.stream(self.alt_stream):
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||||
|
|
||||||
final_hidden_states = self.experts(**kwargs)
|
|
||||||
if not _is_cuda:
|
if not _is_cuda:
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
|
|
||||||
current_stream.wait_stream(self.alt_stream)
|
current_stream.wait_stream(self.alt_stream)
|
||||||
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
||||||
final_hidden_states_out = torch.empty_like(final_hidden_states)
|
final_hidden_states_out = torch.empty_like(final_hidden_states)
|
||||||
|
|
||||||
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||||
final_hidden_states = final_hidden_states_out
|
final_hidden_states = final_hidden_states_out
|
||||||
sm.tag(final_hidden_states)
|
sm.tag(final_hidden_states)
|
||||||
@@ -489,10 +489,9 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
shared_output = self._forward_shared_experts(hidden_states)
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
|
||||||
|
|
||||||
final_hidden_states = self.experts(**kwargs)
|
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||||
if not _is_cuda and not _use_aiter:
|
if not _is_cuda and not _use_aiter:
|
||||||
# fused in biased_grouped_topk so we can skip here
|
# fused in biased_grouped_topk so we can skip here
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
|
|||||||
@@ -509,9 +509,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
with torch.cuda.stream(self.alt_stream):
|
with torch.cuda.stream(self.alt_stream):
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||||
final_hidden_states = self.experts(**kwargs)
|
|
||||||
if not _is_cuda:
|
if not _is_cuda:
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
current_stream.wait_stream(self.alt_stream)
|
current_stream.wait_stream(self.alt_stream)
|
||||||
@@ -552,9 +551,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
|||||||
shared_output = self._forward_shared_experts(hidden_states)
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
kwargs = {"hidden_states": hidden_states}
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
|
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||||
final_hidden_states = self.experts(**kwargs)
|
|
||||||
if not _is_cuda and not _use_aiter:
|
if not _is_cuda and not _use_aiter:
|
||||||
# fused in biased_grouped_topk so we can skip here
|
# fused in biased_grouped_topk so we can skip here
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
|
|||||||
Reference in New Issue
Block a user