Fix torch.compile for MoE (#2033)
This commit is contained in:
@@ -39,7 +39,7 @@ class TestTorchCompile(unittest.TestCase):
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.65
|
||||
self.assertGreaterEqual(metrics["score"], 0.65)
|
||||
|
||||
def run_decode(self, max_new_tokens):
|
||||
response = requests.post(
|
||||
@@ -49,8 +49,8 @@ class TestTorchCompile(unittest.TestCase):
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"ignore_eos": True,
|
||||
},
|
||||
"ignore_eos": True,
|
||||
},
|
||||
)
|
||||
return response.json()
|
||||
@@ -66,7 +66,7 @@ class TestTorchCompile(unittest.TestCase):
|
||||
print(res["text"])
|
||||
throughput = max_tokens / (tok - tic)
|
||||
print(f"Throughput: {throughput} tokens/s")
|
||||
assert throughput >= 152
|
||||
self.assertGreaterEqual(throughput, 152)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user