Add pipeline parallelism for Qwen2 and Qwen3 Model (#6250)

This commit is contained in:
libra
2025-05-18 15:42:55 +08:00
committed by GitHub
parent 01dd39bac1
commit 11553c1a37
5 changed files with 340 additions and 73 deletions

View File

@@ -1,5 +1,6 @@
# Adapted from qwen2.py
import logging
from functools import partial
from typing import Any, Dict, Iterable, Optional, Tuple
@@ -7,6 +8,7 @@ import torch
from torch import nn
from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model
@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix
Qwen3Config = None
logger = logging.getLogger(__name__)
class Qwen3Attention(nn.Module):
def __init__(
@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.quant_config = quant_config
self.model = Qwen3Model(
@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
if self.pp_group.is_last_rank:
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
else:
return self.pooler(hidden_states, forward_batch)
return hidden_states
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight