diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 3f6dba752..6c902655d 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -190,6 +190,7 @@ class Qwen2DecoderLayer(nn.Module): layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -253,6 +254,7 @@ class Qwen2Model(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer, + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.config = config @@ -280,6 +282,7 @@ class Qwen2Model(nn.Module): config=config, quant_config=quant_config, prefix=prefix, + alt_stream=alt_stream, ), pp_rank=self.pp_group.rank_in_group, pp_size=self.pp_group.world_size, diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 0968ba0f4..95f0fcb70 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -291,6 +291,7 @@ class Qwen2MoeDecoderLayer(nn.Module): layer_id: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.config = config @@ -393,6 +394,7 @@ class Qwen2MoeModel(nn.Module): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer, + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -418,6 +420,7 @@ class Qwen2MoeModel(nn.Module): config=config, quant_config=quant_config, prefix=prefix, + alt_stream=alt_stream, ), pp_rank=self.pp_group.rank_in_group, pp_size=self.pp_group.world_size, diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 2035b6c11..c42ac2af0 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -25,15 +25,17 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2Model -from sglang.srt.utils import add_prefix +from sglang.srt.utils import add_prefix, is_cuda Qwen3Config = None logger = logging.getLogger(__name__) +_is_cuda = is_cuda() class Qwen3Attention(nn.Module): @@ -51,6 +53,7 @@ class Qwen3Attention(nn.Module): rms_norm_eps: float = None, attention_bias: bool = False, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -119,15 +122,27 @@ class Qwen3Attention(nn.Module): layer_id=layer_id, prefix=add_prefix("attn", prefix), ) + self.alt_stream = alt_stream def _apply_qk_norm( self, q: torch.Tensor, k: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - q_by_head = q.reshape(-1, self.head_dim) - q_by_head = self.q_norm(q_by_head) + # overlap qk norm + if self.alt_stream is not None and get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + with torch.cuda.stream(self.alt_stream): + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + current_stream.wait_stream(self.alt_stream) + else: + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) q = q_by_head.view(q.shape) - k_by_head = k.reshape(-1, self.head_dim) - k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) return q, k @@ -153,6 +168,7 @@ class Qwen3DecoderLayer(nn.Module): layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -173,6 +189,7 @@ class Qwen3DecoderLayer(nn.Module): rms_norm_eps=config.rms_norm_eps, attention_bias=config.attention_bias, prefix=add_prefix("self_attn", prefix), + alt_stream=alt_stream, ) self.mlp = Qwen3MLP( hidden_size=self.hidden_size, @@ -234,11 +251,13 @@ class Qwen3Model(Qwen2Model): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: + alt_stream = torch.cuda.Stream() if _is_cuda else None super().__init__( config=config, quant_config=quant_config, prefix=prefix, decoder_layer_type=Qwen3DecoderLayer, + alt_stream=alt_stream, ) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index c76326ec0..5a2844438 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -67,6 +67,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, @@ -76,11 +77,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeModel from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher -from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty +from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty Qwen3MoeConfig = None logger = logging.getLogger(__name__) +_is_cuda = is_cuda() class Qwen3MoeSparseMoeBlock(nn.Module): @@ -352,6 +354,7 @@ class Qwen3MoeAttention(nn.Module): attention_bias: bool = False, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -421,15 +424,27 @@ class Qwen3MoeAttention(nn.Module): self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.alt_stream = alt_stream def _apply_qk_norm( self, q: torch.Tensor, k: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: - q_by_head = q.reshape(-1, self.head_dim) - q_by_head = self.q_norm(q_by_head) + # overlap qk norm + if self.alt_stream is not None and get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + with torch.cuda.stream(self.alt_stream): + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + current_stream.wait_stream(self.alt_stream) + else: + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) q = q_by_head.view(q.shape) - k_by_head = k.reshape(-1, self.head_dim) - k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) return q, k @@ -489,6 +504,7 @@ class Qwen3MoeDecoderLayer(nn.Module): layer_id: int, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ) -> None: super().__init__() self.config = config @@ -514,6 +530,7 @@ class Qwen3MoeDecoderLayer(nn.Module): attention_bias=attention_bias, quant_config=quant_config, prefix=add_prefix("self_attn", prefix), + alt_stream=alt_stream, ) self.layer_id = layer_id @@ -657,11 +674,13 @@ class Qwen3MoeModel(Qwen2MoeModel): quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: + alt_stream = torch.cuda.Stream() if _is_cuda else None super().__init__( config=config, quant_config=quant_config, prefix=prefix, decoder_layer_type=Qwen3MoeDecoderLayer, + alt_stream=alt_stream, )