Support shared experts overlap in cutlass moe (#11611)
@@ -800,7 +800,7 @@ class FusedMoE(torch.nn.Module):
             f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
         )

-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs):
         origin_hidden_states_dim = hidden_states.shape[-1]
         assert self.quant_method is not None

@@ -825,6 +825,7 @@ class FusedMoE(torch.nn.Module):
         combine_input = self.quant_method.apply(
             layer=self,
             dispatch_output=dispatch_output,
+            **kwargs,
         )

         final_hidden_states = self.dispatcher.combine(combine_input)
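
The two hunks above plumb arbitrary keyword arguments from `FusedMoE.forward` through to `quant_method.apply`, so callers can opt into the new overlap hooks (`forward_shared_experts`, `alt_stream`) without changing every quant backend's signature. A minimal, self-contained sketch of this pattern (all names here are hypothetical stand-ins, not the real SGLang classes):

```python
class OverlapAwareQuantMethod:
    # Stand-in for a quant method that declares the optional overlap
    # hooks explicitly; backends that don't support overlap simply
    # never receive them.
    def apply(self, layer, dispatch_output, forward_shared_experts=None, alt_stream=None):
        if forward_shared_experts is not None:
            forward_shared_experts()  # in the real code: launched on alt_stream
        return dispatch_output * 2    # toy stand-in for the expert GEMMs


class MoELayer:
    # Stand-in for FusedMoE: forward passes **kwargs through verbatim,
    # so its own signature never grows per-backend parameters.
    def __init__(self):
        self.quant_method = OverlapAwareQuantMethod()

    def forward(self, hidden_states, **kwargs):
        return self.quant_method.apply(
            layer=self, dispatch_output=hidden_states, **kwargs
        )


layer = MoELayer()
layer.forward(1.0)                                              # plain path
layer.forward(1.0, forward_shared_experts=lambda: print("hi"))  # overlap path
```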
@@ -2,6 +2,7 @@
 from __future__ import annotations

 import logging
+import os
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import torch

@@ -1347,6 +1348,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self,
         layer: FusedMoE,
         dispatch_output: StandardDispatchOutput,
+        forward_shared_experts=None,
+        alt_stream=None,
     ) -> CombineInput:

         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

@@ -1418,9 +1421,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )[0]
         if should_use_flashinfer_cutlass_moe_fp4_allgather():
             output, global_output = get_local_dp_buffer(), output
+
+            if forward_shared_experts is not None:
+                alt_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(alt_stream):
+                    forward_shared_experts()
+
             get_tp_group().reduce_scatterv(
                 global_output, output=output, sizes=get_dp_global_num_tokens()
             )
+
+            if forward_shared_experts is not None:
+                torch.cuda.current_stream().wait_stream(alt_stream)
+
             return StandardCombineInput(hidden_states=output)

         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
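
This hunk overlaps the shared-expert computation with the routed experts' `reduce_scatterv`: the alternate stream first waits on the current stream so the shared experts read fully computed activations, then runs while the main stream performs the reduce-scatter, and the main stream waits on the alternate stream before the combined result is consumed. A minimal sketch of the same dual-stream pattern, assuming a CUDA device and with made-up compute functions:

```python
import torch


def overlapped(main_work, side_work, alt_stream):
    # Side stream waits for everything already queued on the current
    # stream, so side_work sees up-to-date inputs.
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        side_out = side_work()   # e.g. the shared experts
    main_out = main_work()       # e.g. reduce_scatterv of the routed output
    # Current stream waits for the side stream before anyone consumes side_out.
    torch.cuda.current_stream().wait_stream(alt_stream)
    # In long-lived code, side_out.record_stream(torch.cuda.current_stream())
    # also keeps the caching allocator from recycling side_out too early.
    return main_out, side_out


if torch.cuda.is_available():
    x = torch.randn(8, 8, device="cuda")
    main_out, side_out = overlapped(lambda: x @ x, lambda: x + 1, torch.cuda.Stream())
    torch.cuda.synchronize()
```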
@@ -655,6 +655,7 @@ class DeepseekV2MoE(nn.Module):
         self._enable_a2a_moe = (
             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
         )
+        self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()

     def get_moe_weights(self):
         return [
@@ -746,9 +747,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)

         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(
-                hidden_states, gemm_output_zero_allocator
-            )
+            if not self._fuse_shared_experts_inside_sbo:
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
@@ -756,7 +758,27 @@ class DeepseekV2MoE(nn.Module):
             shared_output = None
             topk_output = self.topk.empty_topk_output(hidden_states.device)

-        final_hidden_states = self.experts(hidden_states, topk_output)
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
+
+        final_hidden_states = self.experts(
+            hidden_states,
+            topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
+        )
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
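
The final hunk wires the overlap into `DeepseekV2MoE`: when shared experts are fused inside SBO, a zero-argument closure writes its result back into the enclosing scope via `nonlocal`, and the experts call decides when (and on which stream) to invoke it. A stripped-down sketch of that closure handoff, with toy arithmetic standing in for the real expert computations:

```python
def run_experts(hidden_states, forward_shared_experts=None):
    # Stand-in for FusedMoE.forward: backends that receive the callback
    # fire it at the most overlap-friendly point (on an alternate stream
    # in the real code); others never see it.
    if forward_shared_experts is not None:
        forward_shared_experts()
    return hidden_states * 2


def moe_forward(hidden_states, fuse_inside_sbo):
    shared_output = None

    def _forward_shared_experts_and_put_results():
        nonlocal shared_output
        shared_output = hidden_states + 1  # toy shared-expert computation

    if not fuse_inside_sbo:
        # Legacy path: shared experts run eagerly, before routing.
        _forward_shared_experts_and_put_results()

    routed = run_experts(
        hidden_states,
        **(
            dict(forward_shared_experts=_forward_shared_experts_and_put_results)
            if fuse_inside_sbo
            else {}
        ),
    )
    return routed + shared_output


print(moe_forward(1.0, fuse_inside_sbo=True))   # 4.0, shared run inside experts
print(moe_forward(1.0, fuse_inside_sbo=False))  # 4.0, shared run eagerly
```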