added support for tied weights in qwen pipeline parallelism (#6546)
@@ -386,15 +386,36 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)

         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
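For context, the tying across pipeline stages boils down to a point-to-point transfer of the embedding matrix from the first stage to the last one, which then copies it into lm_head. Below is a minimal, self-contained sketch of that idea using plain torch.distributed rather than sglang's pp_group helper; the function name, the gloo backend, and the torchrun launch line are illustrative assumptions, not part of this commit.

    # Minimal sketch: tie lm_head to embed_tokens across pipeline stages.
    # Launch (illustrative): torchrun --nproc_per_node=2 tie_demo.py
    import torch
    import torch.distributed as dist

    def tie_lm_head_to_embeddings(embed_tokens, lm_head, vocab_size, hidden_size,
                                  dtype=torch.float32):
        rank, world = dist.get_rank(), dist.get_world_size()
        first, last = 0, world - 1
        if world == 1:
            # single stage: share the same storage instead of copying
            lm_head.weight = embed_tokens.weight
            return
        if rank == first:
            # first stage owns embed_tokens; ship its weight to the last stage
            dist.send(embed_tokens.weight.data, dst=last)
        elif rank == last:
            # last stage owns lm_head; receive into a buffer of matching shape/dtype
            buf = torch.empty(vocab_size, hidden_size, dtype=dtype)
            dist.recv(buf, src=first)
            with torch.no_grad():
                lm_head.weight.copy_(buf)

    if __name__ == "__main__":
        dist.init_process_group(backend="gloo")
        vocab, hidden = 1000, 64
        rank, world = dist.get_rank(), dist.get_world_size()
        embed = torch.nn.Embedding(vocab, hidden) if rank == 0 else None
        head = torch.nn.Linear(hidden, vocab, bias=False) if rank == world - 1 else None
        tie_lm_head_to_embeddings(embed, head, vocab, hidden)
        dist.destroy_process_group()

In the commit itself the same transfer goes through self.pp_group.send / self.pp_group.recv, with the receive buffer shaped (vocab_size, hidden_size) and using the model's parameter dtype.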
@@ -470,7 +491,15 @@ class Qwen2ForCausalLM(nn.Module):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
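The load_weights change above skips the checkpoint's lm_head.weight when embeddings are tied and, on the last PP rank, substitutes the embedding matrix found in the weight stream. A small, hypothetical illustration of that next(filter(...)) lookup on a plain list of (name, tensor) pairs; the names and shapes below are made up:

    import torch

    # made-up checkpoint stream: (parameter name, tensor) pairs
    weights = [
        ("model.embed_tokens.weight", torch.randn(8, 4)),
        ("model.layers.0.mlp.gate_proj.weight", torch.randn(16, 4)),
    ]

    # same lookup idiom as in load_weights above
    embed_token_weights = next(
        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
    )[1]

    # reuse it as the lm_head weight by copying the values
    lm_head_weight = torch.nn.Parameter(torch.empty(8, 4))
    with torch.no_grad():
        lm_head_weight.copy_(embed_token_weights)

This pattern assumes the weights iterable can still be scanned for the embedding entry at that point (for example a list of pairs); with a one-shot generator the entry may already have been consumed.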
@@ -21,7 +21,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
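The only change in this hunk is importing PPMissingLayer, the placeholder assigned to lm_head on ranks other than the last one. As a rough illustration only (the real sglang class may behave differently), such a placeholder can be a parameter-free module that simply passes its input through:

    import torch.nn as nn

    class PlaceholderLayer(nn.Module):
        """Stand-in for a layer that actually lives on another pipeline rank."""

        def forward(self, x, *args, **kwargs):
            # holds no parameters and does no work; the real output is
            # computed on the rank that owns the layer
            return x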
@@ -249,15 +249,36 @@ class Qwen3ForCausalLM(nn.Module):
         self.model = Qwen3Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)

         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
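The Qwen3 constructor mirrors the Qwen2 one, so the same two tying paths apply: with a single pipeline stage the modules share storage, while with several stages the last rank holds a copy of the embedding weight. A quick, hypothetical sanity check of the single-stage path:

    import torch

    embed = torch.nn.Embedding(1000, 64)
    lm_head = embed  # what `self.lm_head = self.model.embed_tokens` amounts to

    # shared storage: both names point at the same tensor
    assert lm_head.weight.data_ptr() == embed.weight.data_ptr()

    # in the multi-stage path the last rank instead holds a copy, so the
    # corresponding check would be torch.equal(lm_head.weight, received_weight)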
@@ -330,7 +351,15 @@ class Qwen3ForCausalLM(nn.Module):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue