Add pipeline parallelism for Qwen2 and Qwen3 Model (#6250)
@@ -15,12 +15,14 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, Optional, Tuple
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -36,11 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
@@ -50,6 +53,9 @@ from sglang.srt.utils import add_prefix, make_layers
 Qwen2Config = None
 
 
+logger = logging.getLogger(__name__)
+
+
 class Qwen2MLP(nn.Module):
     def __init__(
         self,
@@ -245,15 +251,21 @@ class Qwen2Model(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2DecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
@@ -261,9 +273,14 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
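The constructor hunks above (the dense Qwen2 model, judging by its "# Adapted from llama2.py" header) establish the ownership rule for pipeline parallelism: only the first PP rank materializes embed_tokens, only the last rank materializes the final RMSNorm, and make_layers now returns the [start_layer, end_layer) slice of decoder layers this rank owns, with PPMissingLayer standing in for modules that live on other stages. A minimal runnable sketch of the idea; the split policy shown is an illustrative assumption, not necessarily make_layers' exact behavior:

    from torch import nn

    class ToyPPMissingLayer(nn.Identity):
        # Stand-in for sglang.srt.layers.utils.PPMissingLayer: a placeholder
        # module for parameters owned by another pipeline stage. The real
        # class also offers return_tuple=True so it can replace RMSNorm,
        # whose forward returns an (output, residual) pair.
        pass

    def partition_layers(num_layers: int, pp_rank: int, pp_size: int):
        # Illustrative near-even split, remainder going to the earliest ranks.
        base, rem = divmod(num_layers, pp_size)
        start = pp_rank * base + min(pp_rank, rem)
        return start, start + base + (1 if pp_rank < rem else 0)

    for rank in range(2):
        print(f"pp_rank={rank} owns layers {partition_layers(36, rank, 2)}")
    # pp_rank=0 owns layers (0, 18)
    # pp_rank=1 owns layers (18, 36)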
@@ -280,13 +297,20 @@ class Qwen2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
-        else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -294,7 +318,15 @@ class Qwen2Model(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
     # If this function is called, it should always initialize KV cache scale
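The forward hunks encode the stage-to-stage contract: a non-first rank resumes from the PPProxyTensors handed over by its predecessor instead of embedding input_ids, and a non-last rank returns a {"hidden_states", "residual"} proxy instead of applying the final norm. A self-contained toy of that handoff, with a plain dict standing in for PPProxyTensors and a local variable standing in for the inter-rank transfer:

    import torch

    def stage_forward(layers, hidden_states, residual, is_last, norm):
        for layer in layers:
            hidden_states, residual = layer(hidden_states, residual)
        if not is_last:
            # Middle of the pipeline: hand both streams to the next rank.
            return {"hidden_states": hidden_states, "residual": residual}
        return norm(hidden_states + residual)  # last rank finalizes

    def toy_layer(h, r):
        return h * 2.0, h if r is None else r + h

    norm = torch.nn.LayerNorm(4)
    h = torch.randn(3, 4)
    proxy = stage_forward([toy_layer] * 2, h, None, is_last=False, norm=norm)
    out = stage_forward(
        [toy_layer] * 2, proxy["hidden_states"], proxy["residual"],
        is_last=True, norm=norm,
    )
    print(out.shape)  # torch.Size([3, 4])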
@@ -348,6 +380,7 @@ class Qwen2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(
@@ -379,14 +412,33 @@ class Qwen2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
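The wrapper now gives a scheduler a uniform contract: every rank calls the same forward(); only the last rank's return value is a logits (or pooled-embedding) object, everything else is a proxy to forward to the next rank, and the new start_layer/end_layer properties expose the owned range without reaching into .model. A hedged sketch of the driver side; all names here are illustrative stand-ins, not SGLang's runner API:

    def drive_pipeline(stages, token_ids):
        out = stages[0](token_ids, proxy=None)
        for stage in stages[1:]:
            out = stage(None, proxy=out)  # "send" = pass the proxy along
        return out  # logits-like result from the last stage

    def make_stage(is_last):
        def stage(token_ids, proxy):
            hidden = proxy["hidden_states"] if proxy else [t * 10 for t in token_ids]
            hidden = [h + 1 for h in hidden]
            return hidden if is_last else {"hidden_states": hidden}
        return stage

    print(drive_pipeline([make_stage(False), make_stage(True)], [1, 2, 3]))
    # [12, 22, 32]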
@@ -400,6 +452,17 @@ class Qwen2ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -426,9 +489,15 @@ class Qwen2ForCausalLM(nn.Module):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, loaded_weight)
+
+            if name in params_dict.keys():
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+            else:
+                logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
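Weight loading changes in two complementary ways: get_layer_id lets each rank skip checkpoint tensors for layers outside its [start_layer, end_layer) range, and the final params_dict lookup becomes tolerant, logging a warning instead of raising KeyError, since under PP the embedding, final norm, and most decoder layers legitimately live on other ranks. A toy version of the filter; the regex is an assumption about how get_layer_id (imported from sglang.srt.layers.utils) recognizes names:

    import re

    def toy_get_layer_id(weight_name):
        # Pull the decoder-layer index out of a parameter name, if any.
        m = re.search(r"layers\.(\d+)\.", weight_name)
        return int(m.group(1)) if m else None

    start_layer, end_layer = 18, 36  # what the second of two stages might own
    for name in [
        "model.layers.3.self_attn.qkv_proj.weight",  # other stage -> skip
        "model.layers.20.mlp.gate_up_proj.weight",   # ours -> load
        "model.embed_tokens.weight",                 # no layer id -> fall through
    ]:
        layer_id = toy_get_layer_id(name)
        skip = layer_id is not None and not (start_layer <= layer_id < end_layer)
        print(name, "-> skip" if skip else "-> load")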
@@ -16,9 +16,10 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -26,6 +27,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -52,18 +54,21 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -535,16 +540,21 @@ class Qwen2MoeModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
@@ -552,9 +562,14 @@ class Qwen2MoeModel(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def forward(
         self,
@@ -562,20 +577,35 @@ class Qwen2MoeModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
-        else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        if hidden_states.shape[0] != 0:
-            hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
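The hunks above port the same pattern to the Qwen2-MoE model (identified by its docstring and the vllm qwen2_moe.py reference), with two MoE-specific details: expert_distribution_recorder.set_current_layer(i) now receives global indices from the [start_layer, end_layer) range, and the pre-existing guard that skips the final norm on an empty token slice (hidden_states.shape[0] == 0, possible when a data-parallel attention rank gets no tokens) moves inside the last-rank branch. The guard in isolation, with LayerNorm standing in for the fused RMSNorm kernel the guard keeps away from zero-row input:

    import torch

    norm = torch.nn.LayerNorm(8)
    for batch in (torch.randn(4, 8), torch.empty(0, 8)):
        # Mirrors the guard above: skip the norm entirely on zero rows.
        out = norm(batch) if batch.shape[0] != 0 else batch
        print(tuple(out.shape))
    # (4, 8)
    # (0, 8)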
@@ -589,6 +619,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(
@@ -609,11 +640,29 @@ class Qwen2MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -636,6 +685,16 @@ class Qwen2MoeForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -684,11 +743,14 @@ class Qwen2MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = Qwen2MoeForCausalLM
@@ -1,5 +1,6 @@
 # Adapted from qwen2.py
 
+import logging
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple
 
@@ -7,6 +8,7 @@ import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix
 
 Qwen3Config = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3Attention(nn.Module):
     def __init__(
@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3Model(
@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return self.pooler(hidden_states, forward_batch)
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, loaded_weight)
+
+            if name in params_dict.keys():
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+            else:
+                logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
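The dense Qwen3 file ("# Adapted from qwen2.py"; it defines Qwen3Attention and Qwen3ForCausalLM) needs comparatively few changes because its model class builds on the imported Qwen2Model, so the pipeline-aware constructor and forward come along for free; only the causal-LM wrapper has to thread pp_proxy_tensors through and gate output on the last rank, mirroring the Qwen2 wrapper. A toy of why inheritance carries the weight here:

    class BaseModel:
        # Stand-in for Qwen2Model after this change: forward already
        # understands proxy tensors from a previous stage.
        def forward(self, input_ids, pp_proxy_tensors=None):
            if pp_proxy_tensors is not None:
                return f"resumed from proxy with keys {sorted(pp_proxy_tensors)}"
            return f"embedded {input_ids}"

    class DerivedModel(BaseModel):
        # Stand-in for Qwen3Model: no override needed, PP support is
        # inherited from the shared base.
        pass

    m = DerivedModel()
    print(m.forward([1, 2, 3]))
    print(m.forward(None, pp_proxy_tensors={"hidden_states": 0, "residual": 0}))

The Qwen3-MoE hunks that follow repeat the Qwen2-MoE treatment one-for-one.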
@@ -17,6 +17,7 @@
 
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
 
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -28,6 +29,7 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -57,12 +59,13 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
@@ -70,6 +73,8 @@ from sglang.srt.utils import add_prefix
 
 Qwen3MoeConfig = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
@@ -516,6 +521,7 @@ class Qwen3MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3MoeModel(
@@ -536,12 +542,31 @@ class Qwen3MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -563,6 +588,17 @@ class Qwen3MoeForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -611,11 +647,14 @@ class Qwen3MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = Qwen3MoeForCausalLM
@@ -1,6 +1,7 @@
 """
 Usage:
 python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
+python3 -m unittest test_pp_single_node.TestQwenPPAccuracy.test_pp_consistency
 python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
 """
 
@@ -61,6 +62,60 @@ class TestPPAccuracy(unittest.TestCase):
         time.sleep(5)
 
 
+class TestQwenPPAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # different port to avoid conflicts
+        cls.model_name = "Qwen/Qwen3-8B"  # replace with your Qwen model if needed
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.74)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy drift exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST
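The new TestQwenPPAccuracy boots a real server per configuration (the same --pp-size and --chunked-prefill-size flags the test passes could be given to sglang's regular server launcher to reproduce manually), runs a 200-question gsm8k eval, and requires the pp_size=2 run to land within one percentage point of the pp_size=1 baseline. The acceptance condition reduces to this check; the accuracy figures below are made up for illustration:

    baseline_acc, pp_acc = 0.78, 0.773  # illustrative values, not measured
    assert abs(pp_acc - baseline_acc) <= 0.01, "PP drifted more than 1 point"
    print("within tolerance")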