[optimize] add two stream norm for qwen3 (#7740)

Co-authored-by: ispobock <ispobaoke@gmail.com>
Author: Yi Zhang
Date: 2025-07-04 00:59:17 +08:00
Committed by: GitHub
Parent: 646cef2e2e
Commit: 264dc6e744

4 changed files with 54 additions and 10 deletions

python/sglang/srt/models/qwen3.py

@@ -25,15 +25,17 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cuda

 Qwen3Config = None

 logger = logging.getLogger(__name__)

+_is_cuda = is_cuda()
+

 class Qwen3Attention(nn.Module):
@@ -51,6 +53,7 @@ class Qwen3Attention(nn.Module):
         rms_norm_eps: float = None,
         attention_bias: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -119,15 +122,27 @@ class Qwen3Attention(nn.Module):
             layer_id=layer_id,
             prefix=add_prefix("attn", prefix),
         )
+        self.alt_stream = alt_stream

     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
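
For reference, here is a minimal standalone sketch of the overlap pattern introduced in this hunk (an assumed illustrative example, not SGLang code): the query norm runs on the default stream while the key norm runs on a side stream, and the two wait_stream() calls order the work so the side stream starts only after q and k exist and the default stream resumes only after the key norm has finished. The helper name overlapped_qk_norm is hypothetical, and nn.RMSNorm (PyTorch >= 2.4) stands in for SGLang's RMSNorm layer.

import torch
from torch import nn


def overlapped_qk_norm(q, k, q_norm, k_norm, head_dim, alt_stream):
    current_stream = torch.cuda.current_stream()
    # Side stream must not start until the default stream has produced q and k.
    alt_stream.wait_stream(current_stream)
    q_by_head = q_norm(q.reshape(-1, head_dim))  # runs on the default stream
    with torch.cuda.stream(alt_stream):
        k_by_head = k_norm(k.reshape(-1, head_dim))  # runs concurrently on the side stream
    # Default stream must not read k_by_head until the side stream has finished.
    current_stream.wait_stream(alt_stream)
    return q_by_head.view(q.shape), k_by_head.view(k.shape)


if __name__ == "__main__" and torch.cuda.is_available():
    head_dim = 128
    q_norm = nn.RMSNorm(head_dim, eps=1e-6).cuda()
    k_norm = nn.RMSNorm(head_dim, eps=1e-6).cuda()
    q = torch.randn(4, 8 * head_dim, device="cuda")
    k = torch.randn(4, 4 * head_dim, device="cuda")
    q_out, k_out = overlapped_qk_norm(q, k, q_norm, k_norm, head_dim, torch.cuda.Stream())

Note that in the commit the overlapped path is taken only when self.alt_stream is set and get_is_capture_mode() is true, i.e. while the CUDA graph runner is capturing, so the concurrency is recorded into the captured graph and replayed at decode time; outside capture mode the two norms still run sequentially as before.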
@@ -153,6 +168,7 @@ class Qwen3DecoderLayer(nn.Module):
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -173,6 +189,7 @@ class Qwen3DecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             attention_bias=config.attention_bias,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.mlp = Qwen3MLP(
             hidden_size=self.hidden_size,
@@ -234,11 +251,13 @@ class Qwen3Model(Qwen2Model):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3DecoderLayer,
+            alt_stream=alt_stream,
         )
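
This hunk creates a single torch.cuda.Stream per model (only when running on CUDA) and passes it down through Qwen3DecoderLayer into Qwen3Attention, so every layer shares one side stream instead of allocating its own. A rough sketch of that wiring, with hypothetical TinyLayer/TinyModel classes standing in for the SGLang modules:

import torch
from torch import nn


class TinyLayer(nn.Module):
    def __init__(self, alt_stream=None):
        super().__init__()
        # Shared side stream; left as None on non-CUDA builds so callers
        # fall back to the sequential code path.
        self.alt_stream = alt_stream


class TinyModel(nn.Module):
    def __init__(self, num_layers=4):
        super().__init__()
        alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.layers = nn.ModuleList(
            TinyLayer(alt_stream=alt_stream) for _ in range(num_layers)
        )

Sharing one stream keeps the number of CUDA streams constant regardless of model depth, which is presumably why it is constructed once in Qwen3Model rather than per layer.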