[UT] refactor test_expert_load_balancer and fix broken CI (#1293)

refactor test_expert_load_balancer to keep the ut code style

This PR also fixed the break change from
https://github.com/vllm-project/vllm/pull/16188/files#diff-e2942ece30a5c580437694ffb964bfc664b510c59244c08e5921b8f5cefb4280

This is just a quick fix. We'll support embedding on V1 later

Closes: https://github.com/vllm-project/vllm-ascend/issues/1299

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-06-20 01:02:52 +08:00
committed by GitHub
parent ebb2a70dbb
commit b350edae9a
4 changed files with 205 additions and 140 deletions

View File

@@ -98,11 +98,7 @@ def create_scheduler(
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
**({
"tensors": {}
} if vllm_version_is("0.9.0") else {
"kv_cache_tensors": []
}),
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
@@ -145,8 +141,8 @@ def create_requests(num_requests: int,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
**({
"arrival_time": 0.0
} if vllm_version_is("0.9.0") else {}),
"pooling_params": None
} if not vllm_version_is("0.9.1") else {}),
)
requests.append(request)
return requests
@@ -262,7 +258,9 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(output, model_runner_output)
# Schedule the next step. All three requests are running.
@@ -286,7 +284,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
@@ -337,7 +338,10 @@ def test_stop_via_update_from_output():
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output, model_output)
@@ -385,7 +389,10 @@ def test_stop_via_update_from_output():
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output, model_output)
@@ -432,7 +439,10 @@ def test_stop_via_update_from_output():
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output, model_output)
@@ -474,7 +484,10 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output, model_output)
@@ -524,7 +537,10 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output0, model_runner_output)
# Schedule the next step.
@@ -541,7 +557,10 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
scheduler.update_from_output(scheduler_output1, model_runner_output)
@@ -565,8 +584,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
if vllm_version_is("0.9.0"):
return
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
@@ -593,7 +610,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
)
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
@@ -632,7 +652,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
@@ -727,7 +750,9 @@ def make_output(scheduler: AscendScheduler):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
**({
"pooler_output": []
} if not vllm_version_is("0.9.1") else {}))
def assert_scheduler_empty(scheduler: AscendScheduler):
@@ -744,11 +769,10 @@ def assert_scheduler_empty(scheduler: AscendScheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
if not vllm_version_is("0.9.0"):
assert len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].num_cached_block) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
@@ -789,4 +813,4 @@ def test_memory_leak():
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
assert_scheduler_empty(scheduler)
assert_scheduler_empty(scheduler)