added support for tied weights in qwen pipeline parallelism (#6546)
This commit is contained in:
2
.github/workflows/pr-test.yml
vendored
2
.github/workflows/pr-test.yml
vendored
@@ -84,7 +84,7 @@ jobs:
|
|||||||
bash scripts/ci_install_dependency.sh
|
bash scripts/ci_install_dependency.sh
|
||||||
|
|
||||||
- name: Run test
|
- name: Run test
|
||||||
timeout-minutes: 25
|
timeout-minutes: 30
|
||||||
run: |
|
run: |
|
||||||
cd test/srt
|
cd test/srt
|
||||||
python3 run_suite.py --suite per-commit-2-gpu
|
python3 run_suite.py --suite per-commit-2-gpu
|
||||||
|
|||||||
@@ -386,15 +386,36 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
self.model = Qwen2Model(
|
self.model = Qwen2Model(
|
||||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||||
)
|
)
|
||||||
if config.tie_word_embeddings:
|
|
||||||
self.lm_head = self.model.embed_tokens
|
# handle the lm head on different pp ranks
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
|
if self.pp_group.world_size == 1 and config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("lm_head", prefix),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.lm_head = ParallelLMHead(
|
# ranks other than the last rank will have a placeholder layer
|
||||||
config.vocab_size,
|
self.lm_head = PPMissingLayer()
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
# perform weight tying for PP
|
||||||
prefix=add_prefix("lm_head", prefix),
|
if self.pp_group.world_size > 1 and config.tie_word_embeddings:
|
||||||
)
|
if self.pp_group.is_first_rank:
|
||||||
|
self.pp_group.send(
|
||||||
|
self.model.embed_tokens.weight, dst=self.pp_group.last_rank
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
emb_token_weight = self.pp_group.recv(
|
||||||
|
size=(config.vocab_size, config.hidden_size),
|
||||||
|
dtype=next(self.model.parameters()).dtype,
|
||||||
|
src=self.pp_group.first_rank,
|
||||||
|
)
|
||||||
|
self.lm_head.weight.copy_(emb_token_weight)
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
@@ -470,7 +491,15 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
continue
|
if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
|
||||||
|
# Handle pp weight tying here
|
||||||
|
# find the embed_tokens.weight in the weights
|
||||||
|
embed_token_weights = next(
|
||||||
|
filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
|
||||||
|
)[1]
|
||||||
|
loaded_weight = embed_token_weights
|
||||||
|
else:
|
||||||
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
|
|||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope
|
from sglang.srt.layers.rotary_embedding import get_rope
|
||||||
from sglang.srt.layers.utils import get_layer_id
|
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
@@ -249,15 +249,36 @@ class Qwen3ForCausalLM(nn.Module):
|
|||||||
self.model = Qwen3Model(
|
self.model = Qwen3Model(
|
||||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||||
)
|
)
|
||||||
if config.tie_word_embeddings:
|
|
||||||
self.lm_head = self.model.embed_tokens
|
# handle the lm head on different pp ranks
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
|
if self.pp_group.world_size == 1 and config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("lm_head", prefix),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.lm_head = ParallelLMHead(
|
# ranks other than the last rank will have a placeholder layer
|
||||||
config.vocab_size,
|
self.lm_head = PPMissingLayer()
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
# perform weight tying for PP
|
||||||
prefix=add_prefix("lm_head", prefix),
|
if self.pp_group.world_size > 1 and config.tie_word_embeddings:
|
||||||
)
|
if self.pp_group.is_first_rank:
|
||||||
|
self.pp_group.send(
|
||||||
|
self.model.embed_tokens.weight, dst=self.pp_group.last_rank
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
emb_token_weight = self.pp_group.recv(
|
||||||
|
size=(config.vocab_size, config.hidden_size),
|
||||||
|
dtype=next(self.model.parameters()).dtype,
|
||||||
|
src=self.pp_group.first_rank,
|
||||||
|
)
|
||||||
|
self.lm_head.weight.copy_(emb_token_weight)
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
@@ -330,7 +351,15 @@ class Qwen3ForCausalLM(nn.Module):
|
|||||||
# the checkpoint. Skip them.
|
# the checkpoint. Skip them.
|
||||||
continue
|
continue
|
||||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||||
continue
|
if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
|
||||||
|
# Handle pp weight tying here
|
||||||
|
# find the embed_tokens.weight in the weights
|
||||||
|
embed_token_weights = next(
|
||||||
|
filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
|
||||||
|
)[1]
|
||||||
|
loaded_weight = embed_token_weights
|
||||||
|
else:
|
||||||
|
continue
|
||||||
if name.startswith("model.vision_tower") and name not in params_dict:
|
if name.startswith("model.vision_tower") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -116,6 +116,62 @@ class TestQwenPPAccuracy(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.base_url = "http://127.0.0.1:23334" # different ports to avoid conflicts
|
||||||
|
cls.model_name = (
|
||||||
|
"Qwen/Qwen3-0.6B" # qwen3 < 8B all have tie_word_embeddings = True
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_gsm8k_test(self, pp_size):
|
||||||
|
process = popen_launch_server(
|
||||||
|
self.model_name,
|
||||||
|
self.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--pp-size",
|
||||||
|
pp_size,
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
256,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
time.sleep(5)
|
||||||
|
return metrics
|
||||||
|
finally:
|
||||||
|
kill_process_tree(process.pid)
|
||||||
|
|
||||||
|
def test_baseline_accuracy(self):
|
||||||
|
metrics = self.run_gsm8k_test(pp_size=1)
|
||||||
|
print(f"[Qwen Baseline] {metrics=}")
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.39)
|
||||||
|
|
||||||
|
def test_pp_consistency(self):
|
||||||
|
baseline = self.run_gsm8k_test(pp_size=1)
|
||||||
|
pp_metrics = self.run_gsm8k_test(pp_size=2)
|
||||||
|
|
||||||
|
print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
|
||||||
|
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
pp_metrics["accuracy"],
|
||||||
|
baseline["accuracy"],
|
||||||
|
delta=0.01,
|
||||||
|
msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestFixedBugs(unittest.TestCase):
|
class TestFixedBugs(unittest.TestCase):
|
||||||
def test_chunked_prefill_with_small_bs(self):
|
def test_chunked_prefill_with_small_bs(self):
|
||||||
model = DEFAULT_MODEL_NAME_FOR_TEST
|
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
|||||||
Reference in New Issue
Block a user