Add pipeline parallelism for Qwen2 and Qwen3 Model (#6250)
This commit is contained in:
@@ -15,12 +15,14 @@
|
||||
# Adapted from llama2.py
|
||||
# Modify details for the adaptation of Qwen2 model.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
@@ -36,11 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
kv_cache_scales_loader,
|
||||
@@ -50,6 +53,9 @@ from sglang.srt.utils import add_prefix, make_layers
|
||||
Qwen2Config = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Qwen2MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -245,15 +251,21 @@ class Qwen2Model(nn.Module):
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.pp_group = get_pp_group()
|
||||
|
||||
if self.pp_group.is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
# Use the provided decoder layer type or default to Qwen2DecoderLayer
|
||||
decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
|
||||
self.layers = make_layers(
|
||||
self.layers, self.start_layer, self.end_layer = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda idx, prefix: decoder_layer_type(
|
||||
layer_id=idx,
|
||||
@@ -261,9 +273,14 @@ class Qwen2Model(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
pp_rank=self.pp_group.rank_in_group,
|
||||
pp_size=self.pp_group.world_size,
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
if self.pp_group.is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer(return_tuple=True)
|
||||
|
||||
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
if hasattr(self.config, "scale_emb"):
|
||||
@@ -280,13 +297,20 @@ class Qwen2Model(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
) -> Union[torch.Tensor, PPProxyTensors]:
|
||||
if self.pp_group.is_first_rank:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
residual = None
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
assert pp_proxy_tensors is not None
|
||||
hidden_states = pp_proxy_tensors["hidden_states"]
|
||||
residual = pp_proxy_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
@@ -294,7 +318,15 @@ class Qwen2Model(nn.Module):
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
if not self.pp_group.is_last_rank:
|
||||
return PPProxyTensors(
|
||||
{
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual,
|
||||
}
|
||||
)
|
||||
else:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
@@ -348,6 +380,7 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.pp_group = get_pp_group()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(
|
||||
@@ -379,14 +412,33 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = False,
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
forward_batch,
|
||||
input_embeds,
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
)
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
return hidden_states
|
||||
|
||||
@property
|
||||
def start_layer(self):
|
||||
return self.model.start_layer
|
||||
|
||||
@property
|
||||
def end_layer(self):
|
||||
return self.model.end_layer
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
@@ -400,6 +452,17 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
layer_id = get_layer_id(name)
|
||||
if (
|
||||
layer_id is not None
|
||||
and hasattr(self.model, "start_layer")
|
||||
and (
|
||||
layer_id < self.model.start_layer
|
||||
or layer_id >= self.model.end_layer
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
continue
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
@@ -426,9 +489,15 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name in params_dict.keys():
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
logger.warning(f"Parameter {name} not found in params_dict")
|
||||
|
||||
def get_embed_and_head(self):
|
||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||
|
||||
@@ -16,9 +16,10 @@
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
|
||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -26,6 +27,7 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
@@ -52,18 +54,21 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
|
||||
expert_distribution_recorder = ExpertDistributionRecorder()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Qwen2MoeMLP(nn.Module):
|
||||
def __init__(
|
||||
@@ -535,16 +540,21 @@ class Qwen2MoeModel(nn.Module):
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.pp_group = get_pp_group()
|
||||
|
||||
if self.pp_group.is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
# Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
|
||||
decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
|
||||
self.layers = make_layers(
|
||||
self.layers, self.start_layer, self.end_layer = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda idx, prefix: decoder_layer_type(
|
||||
layer_id=idx,
|
||||
@@ -552,9 +562,14 @@ class Qwen2MoeModel(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
pp_rank=self.pp_group.rank_in_group,
|
||||
pp_size=self.pp_group.world_size,
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
if self.pp_group.is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer(return_tuple=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -562,20 +577,35 @@ class Qwen2MoeModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
) -> Union[torch.Tensor, PPProxyTensors]:
|
||||
if self.pp_group.is_first_rank:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
residual = None
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
assert pp_proxy_tensors is not None
|
||||
hidden_states = pp_proxy_tensors["hidden_states"]
|
||||
residual = pp_proxy_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
expert_distribution_recorder.set_current_layer(i)
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
if not self.pp_group.is_last_rank:
|
||||
return PPProxyTensors(
|
||||
{
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual,
|
||||
}
|
||||
)
|
||||
else:
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -589,6 +619,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.pp_group = get_pp_group()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2MoeModel(
|
||||
@@ -609,11 +640,29 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
forward_batch,
|
||||
input_embeds,
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
)
|
||||
if self.pp_group.is_last_rank:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
@property
|
||||
def start_layer(self):
|
||||
return self.model.start_layer
|
||||
|
||||
@property
|
||||
def end_layer(self):
|
||||
return self.model.end_layer
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
@@ -636,6 +685,16 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
layer_id = get_layer_id(name)
|
||||
if (
|
||||
layer_id is not None
|
||||
and hasattr(self.model, "start_layer")
|
||||
and (
|
||||
layer_id < self.model.start_layer
|
||||
or layer_id >= self.model.end_layer
|
||||
)
|
||||
):
|
||||
continue
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
@@ -684,11 +743,14 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
if name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name in params_dict.keys():
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
logger.warning(f"Parameter {name} not found in params_dict")
|
||||
|
||||
|
||||
EntryClass = Qwen2MoeForCausalLM
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Adapted from qwen2.py
|
||||
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
@@ -7,6 +8,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix
|
||||
|
||||
Qwen3Config = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
|
||||
def __init__(
|
||||
@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.pp_group = get_pp_group()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen3Model(
|
||||
@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = False,
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
forward_batch,
|
||||
input_embeds,
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
)
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
return hidden_states
|
||||
|
||||
@property
|
||||
def start_layer(self):
|
||||
return self.model.start_layer
|
||||
|
||||
@property
|
||||
def end_layer(self):
|
||||
return self.model.end_layer
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
layer_id = get_layer_id(name)
|
||||
if (
|
||||
layer_id is not None
|
||||
and hasattr(self.model, "start_layer")
|
||||
and (
|
||||
layer_id < self.model.start_layer
|
||||
or layer_id >= self.model.end_layer
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if "rotary_emb.inv_freq" in name or "projector" in name:
|
||||
continue
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
if name in params_dict.keys():
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
logger.warning(f"Parameter {name} not found in params_dict")
|
||||
|
||||
def get_embed_and_head(self):
|
||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
@@ -28,6 +29,7 @@ from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
@@ -57,12 +59,13 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import get_layer_id
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
|
||||
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
|
||||
@@ -70,6 +73,8 @@ from sglang.srt.utils import add_prefix
|
||||
|
||||
Qwen3MoeConfig = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
def __init__(
|
||||
@@ -516,6 +521,7 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.pp_group = get_pp_group()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen3MoeModel(
|
||||
@@ -536,12 +542,31 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
positions,
|
||||
forward_batch,
|
||||
input_embeds,
|
||||
pp_proxy_tensors=pp_proxy_tensors,
|
||||
)
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
@property
|
||||
def start_layer(self):
|
||||
return self.model.start_layer
|
||||
|
||||
@property
|
||||
def end_layer(self):
|
||||
return self.model.end_layer
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
@@ -563,6 +588,17 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
layer_id = get_layer_id(name)
|
||||
if (
|
||||
layer_id is not None
|
||||
and hasattr(self.model, "start_layer")
|
||||
and (
|
||||
layer_id < self.model.start_layer
|
||||
or layer_id >= self.model.end_layer
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
@@ -611,11 +647,14 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
if name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
if name in params_dict.keys():
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
logger.warning(f"Parameter {name} not found in params_dict")
|
||||
|
||||
|
||||
EntryClass = Qwen3MoeForCausalLM
|
||||
|
||||
Reference in New Issue
Block a user