Overlap shared expert and routed expert computations (#5121)
This commit is contained in:
@@ -90,7 +90,7 @@ class LlamaMLP(nn.Module):
|
|||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, forward_batch=None):
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.down_proj(x)
|
x, _ = self.down_proj(x)
|
||||||
|
|||||||
@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
|||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
|
ForwardBatch,
|
||||||
|
ForwardMode,
|
||||||
|
PPProxyTensors,
|
||||||
|
)
|
||||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
|
||||||
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
|
from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
|
||||||
|
|
||||||
@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.top_k = config.num_experts_per_tok
|
self.top_k = config.num_experts_per_tok
|
||||||
|
self.device_module = torch.get_device_module()
|
||||||
|
|
||||||
intermediate_size_moe = config.intermediate_size
|
intermediate_size_moe = config.intermediate_size
|
||||||
self.router = ReplicatedLinear(
|
self.router = ReplicatedLinear(
|
||||||
@@ -113,7 +118,25 @@ class Llama4MoE(nn.Module):
|
|||||||
reduce_results=False, # We need to do scatter before reduce
|
reduce_results=False, # We need to do scatter before reduce
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states, forward_batch: ForwardBatch):
|
||||||
|
shared_out, routed_out = self._forward_core(
|
||||||
|
hidden_states, forward_batch.forward_mode
|
||||||
|
)
|
||||||
|
|
||||||
|
out_aD = routed_out + shared_out
|
||||||
|
|
||||||
|
if self.tp_size > 1:
|
||||||
|
out_aD = tensor_model_parallel_all_reduce(out_aD)
|
||||||
|
|
||||||
|
return out_aD
|
||||||
|
|
||||||
|
def _forward_core(self, hidden_states, forward_mode: ForwardMode):
|
||||||
|
if hidden_states.shape[0] < 4:
|
||||||
|
return self._forward_core_shared_routed_overlap(hidden_states)
|
||||||
|
else:
|
||||||
|
return self._forward_core_normal(hidden_states)
|
||||||
|
|
||||||
|
def _forward_core_normal(self, hidden_states):
|
||||||
# router_scores: [num_tokens, num_experts]
|
# router_scores: [num_tokens, num_experts]
|
||||||
router_logits, _ = self.router(hidden_states)
|
router_logits, _ = self.router(hidden_states)
|
||||||
shared_out = self.shared_expert(hidden_states)
|
shared_out = self.shared_expert(hidden_states)
|
||||||
@@ -121,12 +144,35 @@ class Llama4MoE(nn.Module):
|
|||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
)
|
)
|
||||||
out_aD = routed_out + shared_out
|
return shared_out, routed_out
|
||||||
|
|
||||||
if self.tp_size > 1:
|
def _forward_core_shared_routed_overlap(self, hidden_states):
|
||||||
out_aD = tensor_model_parallel_all_reduce(out_aD)
|
alt_stream = _get_or_create_alt_stream(self.device_module)
|
||||||
|
|
||||||
return out_aD
|
alt_stream.wait_stream(self.device_module.current_stream())
|
||||||
|
|
||||||
|
shared_out = self.shared_expert(hidden_states)
|
||||||
|
|
||||||
|
with self.device_module.stream(alt_stream):
|
||||||
|
# router_scores: [num_tokens, num_experts]
|
||||||
|
router_logits, _ = self.router(hidden_states)
|
||||||
|
routed_out = self.experts(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
router_logits=router_logits,
|
||||||
|
)
|
||||||
|
self.device_module.current_stream().wait_stream(alt_stream)
|
||||||
|
|
||||||
|
return shared_out, routed_out
|
||||||
|
|
||||||
|
|
||||||
|
_alt_stream = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_create_alt_stream(device_module):
|
||||||
|
global _alt_stream
|
||||||
|
if _alt_stream is None:
|
||||||
|
_alt_stream = device_module.Stream()
|
||||||
|
return _alt_stream
|
||||||
|
|
||||||
|
|
||||||
class Llama4Attention(nn.Module):
|
class Llama4Attention(nn.Module):
|
||||||
@@ -380,7 +426,7 @@ class Llama4DecoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states = self.feed_forward(hidden_states)
|
hidden_states = self.feed_forward(hidden_states, forward_batch)
|
||||||
|
|
||||||
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
|
||||||
# Scatter
|
# Scatter
|
||||||
|
|||||||
Reference in New Issue
Block a user