Fix Phi3 serving which was broke by earlier change (#5991)
Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from torch import nn
|
||||
from transformers import Phi3Config
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import PPMissingLayer
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
@@ -294,13 +295,24 @@ class Phi3SmallModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
self.pp_group = get_pp_group()
|
||||
if self.pp_group.is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.mup_embedding_multiplier = config.mup_embedding_multiplier
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
self.layers, self.start_layer, self.end_layer = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda idx, prefix: Phi3SmallDecoderLayer(
|
||||
config,
|
||||
@@ -308,6 +320,8 @@ class Phi3SmallModel(nn.Module):
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
pp_rank=self.pp_group.rank_in_group,
|
||||
pp_size=self.pp_group.world_size,
|
||||
prefix=add_prefix("layers", prefix),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user