added support for tied weights in qwen pipeline parallelism (#6546)
.github/workflows/pr-test.yml
@@ -84,7 +84,7 @@ jobs:
           bash scripts/ci_install_dependency.sh

       - name: Run test
-        timeout-minutes: 25
+        timeout-minutes: 30
         run: |
           cd test/srt
           python3 run_suite.py --suite per-commit-2-gpu
python/sglang/srt/models/qwen2.py

@@ -386,15 +386,36 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
+        else:
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
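The constructor now sizes the head per pipeline rank: only the last rank materializes an lm_head (aliasing embed_tokens when the model is tied and there is a single stage), every other rank gets a PPMissingLayer placeholder, and with more than one stage the first rank ships the embedding matrix to the last rank point to point. A minimal sketch of that transfer with raw torch.distributed (assuming a two-rank process group is already initialized; the committed code goes through sglang's pp_group wrapper instead):

import torch
import torch.distributed as dist

def tie_lm_head_across_pp(embed_weight, lm_head_weight, rank, world_size):
    # Sketch only: rank 0 owns embed_tokens and rank world_size - 1 owns
    # lm_head, mirroring the first/last pipeline ranks in the diff above.
    first, last = 0, world_size - 1
    if rank == first:
        # Send the full (vocab_size, hidden_size) embedding matrix.
        dist.send(embed_weight, dst=last)
    elif rank == last:
        # recv() needs a pre-allocated buffer of the right shape and dtype,
        # which is why the committed code passes size= and dtype= explicitly.
        buf = torch.empty_like(lm_head_weight)
        dist.recv(buf, src=first)
        with torch.no_grad():
            lm_head_weight.copy_(buf)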
@@ -470,7 +491,15 @@ class Qwen2ForCausalLM(nn.Module):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
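During weight loading the last rank no longer drops the tied lm_head.weight entry; it substitutes the checkpoint's model.embed_tokens.weight so the ParallelLMHead it allocated ends up holding the embedding matrix. One subtlety: the next(filter(...)) lookup scans the same weights iterable the surrounding loop is consuming, which is only safe when weights is re-iterable (a list, not a one-shot generator). A defensive sketch of the same redirect over a pre-materialized dict (the helper name resolve_tied_lm_head is ours, for illustration only):

def resolve_tied_lm_head(name, loaded_weight, weights_by_name,
                         tie_word_embeddings, pp_world_size, is_last_rank):
    # Sketch of the redirect above: returns the tensor to load for `name`,
    # or None when the entry should be skipped (the old behavior).
    if tie_word_embeddings and "lm_head.weight" in name:
        if pp_world_size > 1 and is_last_rank:
            # weights_by_name = dict(weights), built once before the loop,
            # avoids consuming the weights iterator mid-iteration.
            return weights_by_name["model.embed_tokens.weight"]
        return None
    return loaded_weight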
python/sglang/srt/models/qwen3.py

@@ -21,7 +21,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
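qwen3.py picks up the PPMissingLayer import for the same purpose: ranks other than the last assign it to self.lm_head so attribute access and module traversal keep working on stages that never compute logits. Conceptually it is just a parameter-free passthrough, something like this sketch (not the exact sglang implementation):

import torch.nn as nn

class PPMissingLayer(nn.Module):
    # Stand-in for a layer that lives on another pipeline rank: it holds no
    # parameters (so it contributes nothing to the state dict) and passes
    # its input through untouched if it is ever called.
    def forward(self, *args, **kwargs):
        return args[0] if args else None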
@@ -249,15 +249,36 @@ class Qwen3ForCausalLM(nn.Module):
         self.model = Qwen3Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
+        else:
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -330,7 +351,15 @@ class Qwen3ForCausalLM(nn.Module):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
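The Qwen3ForCausalLM hunks mirror the Qwen2 ones line for line. After construction and loading, the tying of a tie_word_embeddings model can be sanity-checked cheaply: on a single stage the two tensors must literally alias, and across stages the first rank can broadcast a fingerprint of its embedding matrix for the last rank to compare. A sketch, assuming the default torch.distributed group coincides with the pipeline group:

import torch
import torch.distributed as dist

def check_tied_lm_head(model, pp_group):
    # Sketch: only meaningful when config.tie_word_embeddings is set.
    if pp_group.world_size == 1:
        # Single stage: lm_head is embed_tokens, so the storages must alias.
        assert (model.lm_head.weight.data_ptr()
                == model.model.embed_tokens.weight.data_ptr())
        return
    # Multi-stage: compare a cheap fingerprint (a float64 sum) of the two
    # matrices; every rank must enter the broadcast collective.
    fp = torch.zeros(1, dtype=torch.float64, device="cuda")
    if pp_group.is_first_rank:
        fp[0] = model.model.embed_tokens.weight.double().sum()
    dist.broadcast(fp, src=pp_group.first_rank)
    if pp_group.is_last_rank:
        local = model.lm_head.weight.double().sum()
        assert torch.allclose(fp[0], local), "lm_head not tied to embed_tokens"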
@@ -116,6 +116,62 @@ class TestQwenPPAccuracy(unittest.TestCase):
         )


+class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # different port to avoid conflicts
+        cls.model_name = (
+            "Qwen/Qwen3-0.6B"  # Qwen3 models < 8B all have tie_word_embeddings = True
+        )
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.39)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy deviates from baseline by more than 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST
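Both tests need two GPUs. CI runs them through the bumped-timeout step shown at the top (cd test/srt && python3 run_suite.py --suite per-commit-2-gpu); the class can also be invoked directly with python3 -m unittest against its test module, whose filename this diff does not show.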