diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py
index c59d296a6..9ac855c49 100644
--- a/python/sglang/srt/models/phi3_small.py
+++ b/python/sglang/srt/models/phi3_small.py
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,13 +295,19 @@ class Phi3SmallModel(nn.Module):
         super().__init__()
 
         self.config = config
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.start_layer, self.end_layer, self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
@@ -308,6 +315,8 @@ class Phi3SmallModel(nn.Module):
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
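
The pattern these hunks apply is the usual pipeline-parallel split: only the first PP rank instantiates the token embedding (every other rank gets a pass-through placeholder via `PPMissingLayer`), and `make_layers` hands each rank only its contiguous slice of the decoder stack, returning the slice boundaries as `start_layer`/`end_layer`. Below is a minimal standalone sketch of that partitioning logic; `split_layers` and `PassThrough` are illustrative stand-ins written for this example, not the actual `make_layers`/`PPMissingLayer` from `sglang.srt`, and the even per-rank split is an assumption.

```python
import torch.nn as nn


class PassThrough(nn.Module):
    """Placeholder for a layer this pipeline rank does not own (cf. PPMissingLayer)."""

    def forward(self, x):
        return x


def split_layers(num_layers, layer_fn, pp_rank, pp_size):
    """Build only this rank's contiguous slice of the stack; stub out the rest."""
    per_rank = (num_layers + pp_size - 1) // pp_size  # ceil division
    start = pp_rank * per_rank
    end = min(start + per_rank, num_layers)
    layers = nn.ModuleList(
        layer_fn(idx) if start <= idx < end else PassThrough()
        for idx in range(num_layers)
    )
    return layers, start, end


# Example: 32 decoder layers over 4 pipeline ranks; rank 1 owns layers [8, 16).
layers, start, end = split_layers(32, lambda i: nn.Linear(8, 8), pp_rank=1, pp_size=4)
print(start, end)  # -> 8 16
```

Keeping placeholder entries for unowned layers leaves the module list with `num_layers` entries, so layer indices and weight names stay stable across ranks while only the owned slice actually allocates parameters.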