CI: skip special token for engine token ids unit test (#2648)
This commit is contained in:
@@ -19,20 +19,23 @@ class TestEngineTokenIds(unittest.TestCase):
|
|||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
"The future of AI is",
|
"The future of AI is",
|
||||||
]
|
]
|
||||||
|
|
||||||
sampling_params = {"temperature": 0, "top_p": 0.95}
|
sampling_params = {"temperature": 0, "top_p": 0.95}
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
for prompt, output in zip(prompts, outputs):
|
for prompt, output in zip(prompts, outputs):
|
||||||
# SGLang's input_ids has a start token, so we remove it for comparison.
|
deocode_input = tokenizer.decode(
|
||||||
deocode_input = tokenizer.decode(output["input_ids"][1:])
|
output["input_ids"], skip_special_tokens=True
|
||||||
assert (
|
)
|
||||||
deocode_input in prompt
|
assert (deocode_input in prompt) or (
|
||||||
|
prompt in deocode_input
|
||||||
), f"Decode input: {deocode_input} mismatch for: {prompt}"
|
), f"Decode input: {deocode_input} mismatch for: {prompt}"
|
||||||
|
|
||||||
# SGLang's output_ids does not have a start token.
|
deocode_output = tokenizer.decode(
|
||||||
deocode_output = tokenizer.decode(output["output_ids"])
|
output["output_ids"], skip_special_tokens=True
|
||||||
assert (
|
)
|
||||||
deocode_output in output["text"]
|
assert (deocode_output in output["text"]) or (
|
||||||
|
output["text"] in deocode_output
|
||||||
), f"Decode output: {deocode_output} mismatch for: {output['text']}"
|
), f"Decode output: {deocode_output} mismatch for: {output['text']}"
|
||||||
|
|
||||||
llm.shutdown()
|
llm.shutdown()
|
||||||
|
|||||||
Reference in New Issue
Block a user