[UT] refactor test_expert_load_balancer and fix broken CI (#1293)

Refactor test_expert_load_balancer to keep the UT code style consistent.

This PR also fixes the breaking change from
https://github.com/vllm-project/vllm/pull/16188/files#diff-e2942ece30a5c580437694ffb964bfc664b510c59244c08e5921b8f5cefb4280

This is just a quick fix. We'll support embedding on V1 later

Closes: https://github.com/vllm-project/vllm-ascend/issues/1299

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-06-20 01:02:52 +08:00
committed by GitHub
parent ebb2a70dbb
commit b350edae9a
4 changed files with 205 additions and 140 deletions

View File

@@ -31,6 +31,7 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is
EOS_TOKEN_ID = 50256
@@ -130,6 +131,9 @@ def create_requests(num_requests: int,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
**({
"pooling_params": None
} if not vllm_version_is("0.9.1") else {}),
)
requests.append(request)
return requests
@@ -237,7 +241,10 @@ def test_stop_via_update_from_output():
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output, model_output)
@@ -283,7 +290,10 @@ def test_stop_via_update_from_output():
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output, model_output)
@@ -328,7 +338,10 @@ def test_stop_via_update_from_output():
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output, model_output)
@@ -369,7 +382,10 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output, model_output)