Fix torch.compile for MoE (#2033)

This commit is contained in:
Lianmin Zheng
2024-11-14 01:30:24 -08:00
committed by GitHub
parent b275ce0043
commit c3eac1b010
10 changed files with 89 additions and 12 deletions

View File

@@ -39,7 +39,7 @@ class TestTorchCompile(unittest.TestCase):
)
metrics = run_eval(args)
-assert metrics["score"] >= 0.65
+self.assertGreaterEqual(metrics["score"], 0.65)
def run_decode(self, max_new_tokens):
response = requests.post(
@@ -49,8 +49,8 @@ class TestTorchCompile(unittest.TestCase):
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
"ignore_eos": True,
},
)
return response.json()
@@ -66,7 +66,7 @@ class TestTorchCompile(unittest.TestCase):
print(res["text"])
throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s")
-assert throughput >= 152
+self.assertGreaterEqual(throughput, 152)
if __name__ == "__main__":