diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py index 5cd121235..5507182a7 100644 --- a/test/srt/test_pytorch_sampling_backend.py +++ b/test/srt/test_pytorch_sampling_backend.py @@ -43,16 +43,29 @@ class TestPyTorchSamplingBackend(unittest.TestCase): assert metrics["score"] >= 0.65 def test_greedy(self): - response_single = requests.post( - self.base_url + "/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": 32, + + first_text = None + + # ensure the answer is identical across single response + for _ in range(5): + response_single = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + }, }, - }, - ).json() + ).json() + text = response_single["text"] + if first_text is None: + first_text = text + + assert text == first_text, f'"{text}" is not identical to "{first_text}"' + + first_text = None + response_batch = requests.post( self.base_url + "/generate", json={ @@ -63,10 +76,13 @@ class TestPyTorchSamplingBackend(unittest.TestCase): }, }, ).json() - text = response_single["text"] - print(text) + + # ensure the answer is identical among the batch for i in range(10): - assert response_batch[i]["text"] == text + text = response_batch[i]["text"] + if first_text is None: + first_text = text + assert text == first_text, f'"{text}" is not identical to "{first_text}"' if __name__ == "__main__":