Fix Phi3 serving which was broken by an earlier change (#5991)
Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,13 +295,24 @@ class Phi3SmallModel(nn.Module):
         super().__init__()
 
         self.config = config
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.start_layer, self.end_layer, self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
@@ -308,6 +320,8 @@ class Phi3SmallModel(nn.Module):
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
 
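For context, the change follows the usual pipeline-parallel layout: only the first pipeline rank materializes the input embedding, later ranks hold a placeholder (PPMissingLayer), and the decoder layers are partitioned across ranks via make_layers with pp_rank/pp_size. The sketch below illustrates that layout only; it does not use sglang's real get_pp_group/make_layers/PPMissingLayer APIs, and the names partition_layers and ToyPPModel, as well as the even-split partitioning, are illustrative assumptions rather than the library's implementation.

# Minimal sketch of the pipeline-parallel layer layout applied in the diff above.
# NOTE: illustrative only -- partition_layers and ToyPPModel are hypothetical
# stand-ins, not sglang's actual make_layers / Phi3SmallModel.
from typing import Callable, Tuple

from torch import nn


def partition_layers(num_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]:
    """Evenly split num_layers across pp_size ranks (assumed strategy)."""
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    # Last rank absorbs any remainder.
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end


class ToyPPModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        num_layers: int,
        layer_fn: Callable[[int], nn.Module],
        pp_rank: int,
        pp_size: int,
    ) -> None:
        super().__init__()
        # Only the first pipeline rank owns the input embedding; other ranks
        # keep a no-op placeholder (analogous to PPMissingLayer in the diff).
        if pp_rank == 0:
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        else:
            self.embed_tokens = nn.Identity()

        # Each rank builds only its slice of the decoder stack.
        self.start_layer, self.end_layer = partition_layers(num_layers, pp_rank, pp_size)
        self.layers = nn.ModuleList(
            layer_fn(idx) for idx in range(self.start_layer, self.end_layer)
        )


# Example: rank 1 of a 2-way pipeline over 8 layers owns layers 4..7.
model = ToyPPModel(
    vocab_size=32000,
    hidden_size=64,
    num_layers=8,
    layer_fn=lambda idx: nn.Linear(64, 64),
    pp_rank=1,
    pp_size=2,
)
print(model.start_layer, model.end_layer)  # -> 4 8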