Add pipeline parallelism for Qwen2 and Qwen3 Model (#6250)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
|
||||
python3 -m unittest test_pp_single_node.TestQwenPPAccuracy.test_pp_consistency
|
||||
python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
|
||||
"""
|
||||
|
||||
@@ -61,6 +62,60 @@ class TestPPAccuracy(unittest.TestCase):
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
class TestQwenPPAccuracy(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.base_url = "http://127.0.0.1:23334" # different ports to avoid conflicts
|
||||
cls.model_name = "Qwen/Qwen3-8B" # replace with your Qwen Model if needed
|
||||
|
||||
def run_gsm8k_test(self, pp_size):
|
||||
process = popen_launch_server(
|
||||
self.model_name,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--pp-size",
|
||||
pp_size,
|
||||
"--chunked-prefill-size",
|
||||
256,
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=None,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval(args)
|
||||
time.sleep(5)
|
||||
return metrics
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
def test_baseline_accuracy(self):
|
||||
metrics = self.run_gsm8k_test(pp_size=1)
|
||||
print(f"[Qwen Baseline] {metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.74)
|
||||
|
||||
def test_pp_consistency(self):
|
||||
baseline = self.run_gsm8k_test(pp_size=1)
|
||||
pp_metrics = self.run_gsm8k_test(pp_size=2)
|
||||
|
||||
print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
|
||||
|
||||
self.assertAlmostEqual(
|
||||
pp_metrics["accuracy"],
|
||||
baseline["accuracy"],
|
||||
delta=0.01,
|
||||
msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
|
||||
)
|
||||
|
||||
|
||||
class TestFixedBugs(unittest.TestCase):
|
||||
def test_chunked_prefill_with_small_bs(self):
|
||||
model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
|
||||
Reference in New Issue
Block a user