add dual stream for qwen2_moe (#10252)
This commit is contained in:
@@ -65,10 +65,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
|
from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
|
||||||
from sglang.srt.utils import add_prefix, make_layers
|
from sglang.srt.utils import add_prefix, is_cuda, make_layers
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
|
||||||
class Qwen2MoeMLP(nn.Module):
|
class Qwen2MoeMLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -122,11 +124,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
layer_id: int,
|
layer_id: int,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
|
self.alt_stream = alt_stream
|
||||||
if self.tp_size > config.num_experts:
|
if self.tp_size > config.num_experts:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tensor parallel size {self.tp_size} is greater than "
|
f"Tensor parallel size {self.tp_size} is greater than "
|
||||||
@@ -168,6 +172,37 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
self.shared_expert = None
|
self.shared_expert = None
|
||||||
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
||||||
|
|
||||||
|
def _forward_shared_experts(self, hidden_states: torch.Tensor):
|
||||||
|
shared_output = None
|
||||||
|
if self.shared_expert is not None:
|
||||||
|
shared_output = self.shared_expert(hidden_states)
|
||||||
|
if self.shared_expert_gate is not None:
|
||||||
|
shared_output = (
|
||||||
|
F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
|
||||||
|
)
|
||||||
|
return shared_output
|
||||||
|
|
||||||
|
def _forward_router_experts(self, hidden_states: torch.Tensor):
|
||||||
|
# router_logits: (num_tokens, n_experts)
|
||||||
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
|
return self.experts(hidden_states, topk_output)
|
||||||
|
|
||||||
|
def forward_normal_dual_stream(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
current_stream = torch.cuda.current_stream()
|
||||||
|
self.alt_stream.wait_stream(current_stream)
|
||||||
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
|
|
||||||
|
with torch.cuda.stream(self.alt_stream):
|
||||||
|
router_output = self._forward_router_experts(hidden_states)
|
||||||
|
|
||||||
|
current_stream.wait_stream(self.alt_stream)
|
||||||
|
|
||||||
|
return router_output, shared_output
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -176,18 +211,20 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||||
shared_output = None
|
|
||||||
if self.shared_expert is not None:
|
|
||||||
shared_output = self.shared_expert(hidden_states)
|
|
||||||
if self.shared_expert_gate is not None:
|
|
||||||
shared_output = (
|
|
||||||
F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
|
|
||||||
)
|
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||||
router_logits, _ = self.gate(hidden_states)
|
if (
|
||||||
topk_output = self.topk(hidden_states, router_logits)
|
self.alt_stream is not None
|
||||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
and hidden_states.shape[0] > 0
|
||||||
|
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
||||||
|
):
|
||||||
|
final_hidden_states, shared_output = self.forward_normal_dual_stream(
|
||||||
|
hidden_states
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
|
final_hidden_states = self._forward_router_experts(hidden_states)
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
if self.tp_size > 1 and not use_reduce_scatter:
|
if self.tp_size > 1 and not use_reduce_scatter:
|
||||||
@@ -346,6 +383,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
alt_stream=alt_stream,
|
||||||
prefix=add_prefix("mlp", prefix),
|
prefix=add_prefix("mlp", prefix),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -528,8 +566,12 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
self.pp_group = get_pp_group()
|
self.pp_group = get_pp_group()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
alt_stream = torch.cuda.Stream() if _is_cuda else None
|
||||||
self.model = Qwen2MoeModel(
|
self.model = Qwen2MoeModel(
|
||||||
config, quant_config, prefix=add_prefix("model", prefix)
|
config,
|
||||||
|
quant_config,
|
||||||
|
prefix=add_prefix("model", prefix),
|
||||||
|
alt_stream=alt_stream,
|
||||||
)
|
)
|
||||||
self.lm_head = ParallelLMHead(
|
self.lm_head = ParallelLMHead(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user