Add pipeline parallelism for Qwen2 and Qwen3 Model (#6250)
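The diff below (the Qwen3 model file) threads an optional PPProxyTensors argument through forward, lets only the last pipeline rank run the logits processor or pooler, and has every other rank hand its hidden states on toward the next stage. The following is a minimal, self-contained sketch of that control flow, not the SGLang implementation; StageOutput, forward_stage and compute_logits are illustrative stand-ins, not SGLang APIs.

    from dataclasses import dataclass
    from typing import Callable, List

    @dataclass
    class StageOutput:
        # Illustrative stand-in for what a pipeline stage hands back.
        tensors: List[float]
        is_final: bool

    def forward_stage(
        hidden_states: List[float],
        is_last_rank: bool,
        compute_logits: Callable[[List[float]], List[float]],
    ) -> StageOutput:
        # Last rank: finish the model and produce logits (or pooled embeddings).
        if is_last_rank:
            return StageOutput(compute_logits(hidden_states), is_final=True)
        # Intermediate rank: return raw hidden states for the next stage.
        return StageOutput(hidden_states, is_final=False)

    # Two "ranks" chained by hand instead of real send/recv collectives.
    logits_fn = lambda h: [2.0 * x for x in h]  # placeholder for the logits processor
    stage0 = forward_stage([1.0, 2.0, 3.0], is_last_rank=False, compute_logits=logits_fn)
    stage1 = forward_stage(stage0.tensors, is_last_rank=True, compute_logits=logits_fn)
    print(stage1.tensors, stage1.is_final)  # [2.0, 4.0, 6.0] True

In the actual patch this decision is made via self.pp_group.is_last_rank, and intermediate ranks return hidden_states so the runner can ship them to the next stage as proxy tensors.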
@@ -1,5 +1,6 @@
 # Adapted from qwen2.py
 
+import logging
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple
 
@@ -7,6 +8,7 @@ import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix
 
 Qwen3Config = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3Attention(nn.Module):
     def __init__(
@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3Model(
@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
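The load_weights changes above skip checkpoint entries whose layer index falls outside [start_layer, end_layer) and tolerate names missing from params_dict (typically weights owned by another pipeline stage, such as embed_tokens on a non-first rank or lm_head on a non-last rank), logging a warning instead of raising. A minimal sketch of that filtering rule, assuming weight names of the form "model.layers.<i>....", is shown below; get_layer_id here is a local stand-in for sglang.srt.layers.utils.get_layer_id.

    import re
    from typing import Optional

    def get_layer_id(name: str) -> Optional[int]:
        # Extract the transformer-block index from a checkpoint name, if any.
        match = re.search(r"layers\.(\d+)\.", name)
        return int(match.group(1)) if match else None

    def keep_weight(name: str, start_layer: int, end_layer: int) -> bool:
        # A stage owning layers [start_layer, end_layer) loads only those blocks;
        # non-layer weights (embeddings, norms, lm_head) fall through to the usual
        # params_dict lookup, which is why the patch also checks membership there.
        layer_id = get_layer_id(name)
        if layer_id is None:
            return True
        return start_layer <= layer_id < end_layer

    # A stage owning layers 16..31:
    print(keep_weight("model.layers.3.self_attn.qkv_proj.weight", 16, 32))  # False
    print(keep_weight("model.layers.20.mlp.gate_proj.weight", 16, 32))      # True
    print(keep_weight("model.embed_tokens.weight", 16, 32))                 # True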