Fix PP for Qwen3 MoE (#6709)
This commit is contained in:
@@ -812,9 +812,9 @@ class Qwen3MoeForCausalLM(nn.Module):
                 logger.warning(f"Parameter {name} not found in params_dict")

         self.routed_experts_weights_of_layer = {
-            layer_id: layer.mlp.get_moe_weights()
-            for layer_id, layer in enumerate(self.model.layers)
-            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock)
+            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+            for layer_id in range(self.start_layer, self.end_layer)
+            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
         }

     @classmethod
@@ -121,7 +121,7 @@ class TestQwenPPAccuracy(unittest.TestCase):
 class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
+        cls.base_url = "http://127.0.0.1:23335"  # different ports to avoid conflicts
         cls.model_name = (
             "Qwen/Qwen3-0.6B"  # qwen3 < 8B all have tie_word_embeddings = True
         )
@@ -176,6 +176,62 @@ class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
         )

class TestQwenMoePPAccuracy(unittest.TestCase):
    """GSM8K accuracy checks for a Qwen3 MoE model under pipeline parallelism (PP).

    Launches a server per configuration, runs the few-shot GSM8K eval against it,
    and compares PP accuracy to the single-stage baseline.
    """

    @classmethod
    def setUpClass(cls):
        # Port differs from the sibling PP test classes so servers never collide.
        cls.base_url = "http://127.0.0.1:23336"  # different ports to avoid conflicts
        cls.model_name = "Qwen/Qwen3-30B-A3B"  # replace with your Qwen Model if needed

    def run_gsm8k_test(self, pp_size):
        """Launch a server with the given ``pp_size``, run GSM8K, return metrics.

        The server process tree is always killed, even if the eval raises.
        """
        process = popen_launch_server(
            self.model_name,
            self.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--pp-size",
                pp_size,  # NOTE(review): passed as int — assumes the launcher stringifies args
                "--chunked-prefill-size",
                256,
            ],
        )

        try:
            args = SimpleNamespace(
                num_shots=5,
                data_path=None,
                num_questions=200,
                max_new_tokens=512,
                parallel=128,
                host="http://127.0.0.1",
                port=int(self.base_url.split(":")[-1]),
            )
            metrics = run_eval(args)
            time.sleep(5)  # let in-flight requests drain before teardown
            return metrics
        finally:
            # Guaranteed cleanup of the server and all its children.
            kill_process_tree(process.pid)

    def test_baseline_accuracy(self):
        # pp_size=1 is the no-pipeline reference run.
        metrics = self.run_gsm8k_test(pp_size=1)
        print(f"[Qwen Baseline] {metrics=}")
        self.assertGreater(metrics["accuracy"], 0.74)

    def test_pp_consistency(self):
        baseline = self.run_gsm8k_test(pp_size=1)
        pp_metrics = self.run_gsm8k_test(pp_size=2)

        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")

        # PP must not lose more than one percentage point vs. the baseline.
        self.assertGreaterEqual(
            pp_metrics["accuracy"],
            baseline["accuracy"] - 0.01,
            msg=(
                f"PP accuracy dropped more than 1% compared to baseline. "
                f"Baseline: {baseline['accuracy']:.2%}, PP: {pp_metrics['accuracy']:.2%}"
            ),
        )
class TestFixedBugs(unittest.TestCase):
    def test_chunked_prefill_with_small_bs(self):
        model = DEFAULT_MODEL_NAME_FOR_TEST

Reference in New Issue
Block a user