added support for tied weights in qwen pipeline parallelism (#6546)

This commit is contained in:
Shenggui Li
2025-05-25 15:00:56 +08:00
committed by GitHub
parent 1a39979993
commit 3f23d8cdf1
4 changed files with 134 additions and 20 deletions

View File

@@ -116,6 +116,62 @@ class TestQwenPPAccuracy(unittest.TestCase):
)
class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.base_url = "http://127.0.0.1:23334" # different ports to avoid conflicts
cls.model_name = (
"Qwen/Qwen3-0.6B" # qwen3 < 8B all have tie_word_embeddings = True
)
def run_gsm8k_test(self, pp_size):
process = popen_launch_server(
self.model_name,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--pp-size",
pp_size,
"--chunked-prefill-size",
256,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
time.sleep(5)
return metrics
finally:
kill_process_tree(process.pid)
def test_baseline_accuracy(self):
metrics = self.run_gsm8k_test(pp_size=1)
print(f"[Qwen Baseline] {metrics=}")
self.assertGreater(metrics["accuracy"], 0.39)
def test_pp_consistency(self):
baseline = self.run_gsm8k_test(pp_size=1)
pp_metrics = self.run_gsm8k_test(pp_size=2)
print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
self.assertAlmostEqual(
pp_metrics["accuracy"],
baseline["accuracy"],
delta=0.01,
msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
)
class TestFixedBugs(unittest.TestCase):
def test_chunked_prefill_with_small_bs(self):
model = DEFAULT_MODEL_NAME_FOR_TEST