Clean up unit tests (#1020)
This commit is contained in:
@@ -461,8 +461,11 @@ class ModelTpServer:
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
else:
|
||||
if self.tokenizer is None:
|
||||
for i, req in enumerate(batch.reqs):
|
||||
next_token_ids.extend(req.sampling_params.stop_token_ids)
|
||||
next_token_ids = []
|
||||
for req in batch.reqs:
|
||||
next_token_ids.append(
|
||||
next(iter(req.sampling_params.stop_token_ids))
|
||||
)
|
||||
else:
|
||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ def test_decode_json():
|
||||
assert isinstance(js_obj["population"], int)
|
||||
|
||||
|
||||
def test_expert_answer():
|
||||
def test_expert_answer(check_answer=True):
|
||||
@sgl.function
|
||||
def expert_answer(s, question):
|
||||
s += "Question: " + question + "\n"
|
||||
@@ -167,7 +167,9 @@ def test_expert_answer():
|
||||
)
|
||||
|
||||
ret = expert_answer.run(question="What is the capital of France?", temperature=0.1)
|
||||
assert "paris" in ret.text().lower()
|
||||
|
||||
if check_answer:
|
||||
assert "paris" in ret.text().lower(), f"Answer: {ret.text()}"
|
||||
|
||||
|
||||
def test_tool_use():
|
||||
|
||||
Reference in New Issue
Block a user