Fix some ci issue and refactor modelrunner (#2445)

### What this PR does / why we need it?
Fix some ci issue and refactor modelrunner

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with existing test.

- vLLM version: v0.10.0
- vLLM main:
4d9c61993a

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
Mengqing Cao
2025-08-20 09:01:04 +08:00
committed by GitHub
parent 955411611c
commit 1327f9be1c
28 changed files with 1612 additions and 1020 deletions

View File

@@ -4,7 +4,7 @@ from typing import Any, Optional
import pytest
import torch
import torch.nn.functional as F
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -66,7 +66,7 @@ def create_sampling_metadata(
output_token_ids=[],
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=LogitsProcessorManager())
logitsprocs=LogitsProcessors())
########################### Tests for Greedy Sampling ###################

View File

@@ -9,6 +9,7 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionState,
AscendMetadata,
CommonAttentionState)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
class TestAscendAttentionBackend(TestBase):
@@ -67,8 +68,12 @@ class TestAscendAttentionBackend(TestBase):
class TestAscendAttentionMetadataBuilder(TestBase):
def setUp(self):
self.mock_runner = MagicMock()
self.builder = AscendAttentionMetadataBuilder(self.mock_runner)
self.mock_vllm_config = MagicMock()
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64
self.mock_device = 'cpu:0'
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
self.mock_device)
def test_reorder_batch(self):
mock_input_batch = MagicMock()
@@ -86,31 +91,28 @@ class TestAscendAttentionMetadataBuilder(TestBase):
def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
mock_npu_format_cast,
mock_ascend_metadata):
num_reqs = 2
num_actual_tokens = 10
max_query_len = 5
self.mock_runner.input_batch.block_table = [MagicMock()]
self.mock_runner.input_batch.block_table[
0].get_device_tensor.return_value = torch.zeros((10, 10))
self.mock_runner.max_num_blocks_per_req = 10
self.mock_runner.query_lens = torch.tensor([3, 4])
self.mock_runner.seq_lens_cpu = torch.tensor([5, 6])
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
self.mock_runner.device = 'cpu:0'
self.mock_runner.attn_mask = torch.ones((10, 10))
self.mock_runner.attn_state = AscendAttentionState.PrefillNoCache
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 3, 7])
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
seq_lens_cpu=torch.tensor([5, 6]),
num_reqs=2,
num_actual_tokens=10,
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache)
mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_2d.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor
self.builder.build(
num_reqs,
num_actual_tokens,
max_query_len,
)
self.builder.build(common_attn_metadata, mock_model)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@@ -120,51 +122,53 @@ class TestAscendAttentionMetadataBuilder(TestBase):
def test_build_chunked_prefill(self, mock_ascend_attention_state,
mock_is_310p, mock_nd_to_nz_spec,
mock_npu_format_cast, mock_ascend_metadata):
num_reqs = 3
num_actual_tokens = 15
max_query_len = 6
self.mock_runner.input_batch.block_table = [MagicMock()]
self.mock_runner.input_batch.block_table[
0].get_device_tensor.return_value = torch.zeros((10, 10))
self.mock_runner.max_num_blocks_per_req = 10
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
self.mock_runner.device = 'cpu:0'
self.mock_runner.attn_mask = torch.ones((15, 15))
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
mock_ascend_attention_state = MagicMock()
mock_ascend_attention_state.PrefillNoCache = 0
mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_spec.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor
self.builder.build(num_reqs, num_actual_tokens, max_query_len)
self.builder.build(common_attn_metadata, mock_model)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
num_reqs = 3
num_actual_tokens = 15
max_query_len = 6
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
mock_model = MagicMock()
self.mock_runner.input_batch.block_table = [MagicMock()]
self.mock_runner.input_batch.block_table[
0].get_device_tensor.return_value = torch.zeros((10, 10))
self.mock_runner.max_num_blocks_per_req = 10
self.mock_runner.query_lens = torch.tensor([2, 3, 4])
self.mock_runner.seq_lens_cpu = torch.tensor([4, 5, 6])
self.mock_runner.slot_mapping_cpu = torch.tensor(range(20))
self.mock_runner.device = 'cpu:0'
self.mock_runner.attn_mask = torch.ones((15, 15))
self.mock_runner.attn_state = AscendAttentionState.ChunkedPrefill
self.mock_runner.query_start_loc_cpu = torch.tensor([0, 2, 5, 9])
self.builder.build(num_reqs, num_actual_tokens, max_query_len)
self.builder.build(common_attn_metadata, mock_model)
class TestAscendAttentionBackendImpl(TestBase):

View File

@@ -1,6 +1,5 @@
from unittest.mock import MagicMock, patch
import numpy as np
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
@@ -12,6 +11,7 @@ from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
class TestAscendMLABackend(TestBase):
@@ -178,40 +178,41 @@ class TestAscendMLAMetadata(TestBase):
class TestAscendMLAMetadataBuilder(TestBase):
def test_ascend_mla_metadata_builder_default(self):
runner = MagicMock()
runner.scheduler_config = MagicMock()
runner.model_config = MagicMock()
runner.scheduler_config.max_num_seqs = 4
runner.model_config.max_model_len = 1024
runner.model_config.get_head_size.return_value = 64
runner.model_config.dtype = torch.float16
runner.chunked_prefill_enabled = False
runner.device = "cpu"
runner.block_size = 16
runner.decode_token_per_req = 1
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(runner)
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
self.assertEqual(builder.runner, runner)
self.assertEqual(builder.block_size, runner.block_size)
self.assertEqual(builder.chunked_prefill_enabled,
runner.chunked_prefill_enabled)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
self.assertEqual(builder.torchair_graph_enabled, True)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def test_reorder_batch_with_torchair_graph(self, ascend_config):
runner = MagicMock()
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
builder = AscendMLAMetadataBuilder(runner)
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
@@ -230,22 +231,23 @@ class TestAscendMLAMetadataBuilder(TestBase):
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertFalse(modified)
self.assertEqual(builder._num_decodes, 4)
self.assertEqual(builder._num_prefills, 0)
self.assertEqual(builder._num_decode_tokens, 7)
self.assertEqual(builder._num_prefill_tokens, 0)
input_batch.swap_states.assert_not_called()
def test_reorder_batch_without_torchair_graph(self):
ascend_config = MagicMock()
runner = MagicMock()
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(runner)
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
@@ -264,10 +266,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertTrue(modified)
self.assertEqual(builder._num_decodes, 2)
self.assertEqual(builder._num_prefills, 2)
self.assertEqual(builder._num_decode_tokens, 2)
self.assertEqual(builder._num_prefill_tokens, 5)
input_batch.swap_states.assert_called_once_with(1, 2)
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
@@ -275,11 +273,13 @@ class TestAscendMLAMetadataBuilder(TestBase):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
runner = MagicMock()
runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
builder = AscendMLAMetadataBuilder(runner=runner)
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
@@ -292,11 +292,13 @@ class TestAscendMLAMetadataBuilder(TestBase):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
runner = MagicMock()
runner.graph_block_tables = torch.zeros((8, 4), dtype=torch.int32)
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
builder = AscendMLAMetadataBuilder(runner=runner)
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
@@ -310,11 +312,13 @@ class TestAscendMLAMetadataBuilder(TestBase):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
runner = MagicMock()
runner.graph_block_tables = np.zeros((8, 64), dtype=np.int32)
runner.chunked_prefill_enabled = False
runner.decode_token_per_req = 1
builder = AscendMLAMetadataBuilder(runner=runner)
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
@@ -329,38 +333,45 @@ class TestAscendMLAMetadataBuilder(TestBase):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
runner = MagicMock()
runner.model_config = MagicMock()
runner.device = "cpu"
runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
runner.model_config.get_head_size.return_value = 64
runner.chunked_prefill_enabled = False
runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
runner.dtype = torch.float16
runner.decode_token_per_req = 1
builder = AscendMLAMetadataBuilder(runner=runner,
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_device = 'cpu'
builder = AscendMLAMetadataBuilder(mock_vllm_config,
mock_device,
metadata_cls=AscendMLAMetadata)
builder.rope_dim = 64
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
metadata = builder.build_torchair_graph_dummy(3, 3)
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=3,
num_actual_tokens=3,
decode_token_per_req=1,
actual_seq_lengths_q=[0, 1, 2],
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
sin_golden = torch.ones(3,
1,
1,
64,
dtype=runner.dtype,
device=runner.device)
dtype=torch.float16,
device=mock_device)
cos_golden = torch.ones(3,
1,
1,
64,
dtype=runner.dtype,
device=runner.device)
dtype=torch.float16,
device=mock_device)
self.assertIsInstance(metadata, AscendMLAMetadata)
self.assertEqual(metadata.num_input_tokens, 3)

View File

@@ -11,7 +11,7 @@ from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
@@ -68,7 +68,6 @@ def make_output(scheduler):
for i, req in enumerate(scheduler.running)
},
sampled_token_ids=[[1000]] * len(scheduler.running),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -296,7 +295,6 @@ class TestAscendScheduler(TestBase):
},
sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -352,7 +350,6 @@ class TestAscendScheduler(TestBase):
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -407,7 +404,6 @@ class TestAscendScheduler(TestBase):
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -451,7 +447,6 @@ class TestAscendScheduler(TestBase):
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -509,7 +504,6 @@ class TestAscendScheduler(TestBase):
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -526,7 +520,6 @@ class TestAscendScheduler(TestBase):
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -586,13 +579,14 @@ class TestAscendScheduler(TestBase):
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
@@ -633,7 +627,6 @@ class TestAscendScheduler(TestBase):
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -674,10 +667,6 @@ class TestAscendScheduler(TestBase):
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block), 0)
self.assertEqual(len(scheduler.kv_cache_manager.req_to_block_hashes),
0)
self.assertEqual(len(scheduler.kv_cache_manager.req_to_block_hashes),
0)
num_free_blocks = (scheduler.kv_cache_manager.block_pool.
free_block_queue.num_free_blocks)
self.assertEqual(

View File

@@ -42,7 +42,8 @@ def test_basic_lifecycle():
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
do_remote_prefill=True,
block_size=BLOCK_SIZE)
scheduler.add_request(request)
request_id = request.request_id

View File

@@ -10,6 +10,8 @@ import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
@@ -39,7 +41,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
@@ -118,6 +119,9 @@ def create_scheduler(
)
_none_hash_initialized = False
def create_request(
request_id: int,
num_tokens: int = 10,
@@ -126,8 +130,15 @@ def create_request(
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
) -> Request:
"""Make dummy request for testing."""
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
kv_transfer_params: Optional[dict[str, Any]] = None
@@ -164,6 +175,7 @@ def create_request(
"pooling_params": []
} if not vllm_version_is("0.9.1") else {}),
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
req.kv_transfer_params = kv_transfer_params
return req
@@ -196,7 +208,6 @@ def create_model_runner_output(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],

View File

@@ -184,6 +184,11 @@ class MockQuantMethod(nn.Module):
class MockFusedMoEMethod(FusedMoEMethodBase):
# TODO(bnell): also pass quant_config?
moe = MagicMock()
def __init__(self):
super().__init__(self.moe)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,

View File

@@ -536,10 +536,10 @@ class TestNPUPlatform(TestBase):
mock_config = MagicMock(spec=ModelConfig)
self.assertTrue(self.platform.supports_v1(mock_config))
def test_get_piecewise_backend_cls_returns_correct_value(self):
def test_get_static_graph_wrapper_cls_returns_correct_value(self):
self.assertEqual(
self.platform.get_piecewise_backend_cls(),
"vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend",
self.platform.get_static_graph_wrapper_cls(),
"vllm_ascend.compilation.acl_graph.ACLGraphWrapper",
)
@patch("torch.distributed.is_hccl_available", return_value=True)

View File

@@ -1,161 +1,371 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import inspect
from collections.abc import Sequence
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from tests.ut.base import TestBase
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100
MAX_NUM_PROMPT_TOKENS = 64
def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
return CachedRequestState(
req_id=req_id,
prompt_token_ids=prompt,
mm_kwargs=[],
mm_positions=[],
sampling_params=SamplingParams(),
pooling_params=None,
generator=None,
block_ids=([], ),
num_computed_tokens=0,
output_token_ids=output,
def _compare_objs(obj1,
obj2,
skip: Sequence = ("logitsprocs", "batch_update_builder")):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
a[0] for a in attrs
if not (a[0].startswith('__') and a[0].endswith('__'))
])
for attr_name in attr_names:
if attr_name in skip:
continue
a = getattr(obj1, attr_name)
b = getattr(obj2, attr_name)
is_same = False
if isinstance(a, torch.Tensor):
if (a.numel() == 0 or b.numel() == 0):
is_same = (a.numel() == 0 and b.numel() == 0)
elif torch.allclose(a, b):
is_same = True
elif isinstance(a, np.ndarray):
if np.allclose(a, b):
is_same = True
elif isinstance(a, MultiGroupBlockTable):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
elif a == b:
is_same = True
assert is_same, f"Attribute {attr_name} is different"\
f" in {obj1} and {obj2}: {a} != {b}"
def _remove_requests(input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> set[str]:
"""
Remove some requests randomly from the batch and returns
set of request removed
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove: set[int] = set()
for _ in range(num_reqs_to_remove):
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_ids_to_remove: set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return req_ids_to_remove
def _construct_expected_sampling_metadata(
reqs: list[CachedRequestState],
req_ids_retained: set[int],
req_id_index_in_input_batch: dict[str, int],
device: torch.device,
) -> SamplingMetadata:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
"""
num_reqs = len(req_ids_retained)
output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
presence_penalties = [0.0 for _ in range(num_reqs)]
frequency_penalties = [0.0 for _ in range(num_reqs)]
repetition_penalties = [1.0 for _ in range(num_reqs)]
top_k = [0 for _ in range(num_reqs)]
top_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
allowed_token_ids_mask = torch.zeros(num_reqs,
VOCAB_SIZE,
dtype=torch.bool,
device=device)
bad_words_token_ids = {}
for req in reqs:
if req.req_id not in req_ids_retained:
continue
index_in_input_batch = req_id_index_in_input_batch[req.req_id]
output_token_ids[index_in_input_batch] = req.output_token_ids
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
presence_penalties[
index_in_input_batch] = req.sampling_params.presence_penalty
frequency_penalties[index_in_input_batch] = (
req.sampling_params.frequency_penalty)
repetition_penalties[index_in_input_batch] = (
req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
temperature[index_in_input_batch] = req.sampling_params.temperature
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True
if req.sampling_params.bad_words_token_ids:
bad_words_token_ids[
index_in_input_batch] = req.sampling_params.bad_words_token_ids
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False,
all_random=True,
top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
prompt_token_ids,
pad=VOCAB_SIZE,
device=torch.device(device),
dtype=torch.int64,
),
frequency_penalties=torch.tensor(frequency_penalties,
dtype=torch.float,
device=device),
presence_penalties=torch.tensor(presence_penalties,
dtype=torch.float,
device=device),
repetition_penalties=torch.tensor(repetition_penalties,
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
logitsprocs=LogitsProcessors(),
)
class TestInputBatch(TestBase):
def _create_sampling_params():
return SamplingParams(
top_k=np.random.randint(1, 10),
top_p=np.random.uniform(0.0, 1.0),
presence_penalty=np.random.uniform(-2.0, 2.0),
repetition_penalty=np.random.uniform(0.0, 2.0),
frequency_penalty=np.random.uniform(-2.0, 2.0),
min_tokens=np.random.randint(1, 10),
stop_token_ids=[
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(10))
],
logit_bias={0: np.random.uniform(-3.0, 3.0)},
)
def setUp(self):
self.max_num_reqs = 10
self.max_model_len = 32
self.max_num_batched_tokens = 132
self.vocab_size = 1000
self.device = torch.device("cpu")
self.block_sizes = [128]
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_batched_tokens,
device=self.device,
pin_memory=False,
vocab_size=self.vocab_size,
block_sizes=self.block_sizes,
)
self.cached_request_state = mock_cached_request_state()
def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
]
output_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
]
return CachedRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_kwargs=[],
mm_positions=[],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
)
def test_shapes_and_defaults(self):
# torch tensor shape assertions
self.assertEqual(self.input_batch.token_ids_cpu_tensor.shape,
(self.max_num_reqs, self.max_model_len))
self.assertEqual(self.input_batch.temperature.shape,
(self.max_num_reqs, ))
self.assertEqual(self.input_batch.top_k.shape, (self.max_num_reqs, ))
self.assertEqual(self.input_batch.min_p_cpu_tensor.shape,
(self.max_num_reqs, ))
# numpy shape assertions
self.assertEqual(self.input_batch.token_ids_cpu.shape,
(self.max_num_reqs, self.max_model_len))
self.assertEqual(self.input_batch.num_tokens.shape,
(self.max_num_reqs, ))
self.assertEqual(self.input_batch.num_tokens.shape,
(self.max_num_reqs, ))
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
"""
Tests the logic for managing sampling metadata in the InputBatch.
# type assertions
self.assertIsInstance(self.input_batch.greedy_reqs, set)
self.assertIsInstance(self.input_batch.req_id_to_index, dict)
self.assertIsInstance(self.input_batch.sampling_metadata,
SamplingMetadata)
self.assertIsInstance(self.input_batch.block_table,
MultiGroupBlockTable)
self.assertIsNone(self.input_batch.allowed_token_ids_mask)
self.assertIsNone(self.input_batch.allowed_token_ids_mask_cpu_tensor)
This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
def test_add_request(self):
# case1: add a new req
self.input_batch.add_request(self.cached_request_state)
self.assertIn(self.cached_request_state.req_id,
self.input_batch.req_id_to_index)
req_index = self.input_batch.req_id_to_index[
self.cached_request_state.req_id]
self.assertEqual(self.input_batch.num_prompt_tokens[req_index],
len(self.cached_request_state.prompt_token_ids))
self.assertEqual(self.input_batch.num_tokens[req_index],
self.cached_request_state.num_tokens)
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# case2: add an existing req, maybe need update
self.cached_request_state.output_token_ids.extend([7, 8, 9])
self.cached_request_state.num_computed_tokens += 3
cached_index = self.input_batch.req_id_to_index[
self.cached_request_state.req_id]
self.input_batch.add_request(self.cached_request_state, cached_index)
# check if this index in the input_batch is updated
# This np arrat "token_ids_cpu" should be filled with prompt_token_ids + output_token_ids
self.assertTrue(
np.all(self.input_batch.token_ids_cpu[
cached_index, :self.cached_request_state.num_tokens]),
msg=f"Token IDs at index {cached_index} did not update correctly.")
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
assigned_req_index = input_batch.add_request(req)
assert req_index == assigned_req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
# case3: add req that greater than max_num_reqs
with self.assertRaises(AssertionError):
self.input_batch.add_request(self.cached_request_state,
req_index=self.max_num_reqs)
# Remove some requests
req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
# case4: add req that out of max_model_len
long_prompt = list(range(self.max_model_len + 1))
long_request = mock_cached_request_state(req_id="2",
prompt=long_prompt,
output=[10])
with self.assertRaises(ValueError) as cm:
self.input_batch.add_request(long_request)
self.assertIn("could not broadcast", str(cm.exception))
# Compact the input batch
input_batch.condense()
def test_remove_request(self):
self.input_batch.add_request(self.cached_request_state)
req_index = self.input_batch.remove_request(
self.cached_request_state.req_id)
self.assertIsNotNone(req_index)
self.assertNotIn(self.cached_request_state.req_id,
self.input_batch.req_id_to_index)
self.assertIsNone(self.input_batch._req_ids[req_index])
# Generate the sampling metadata
sampling_metadata = input_batch._make_sampling_metadata()
def test_condense(self):
# Let's say we have some requests like below
# Index Req ID
# 0 1
# 1 2
# 2 3
# 3 4
for i in range(4):
request = mock_cached_request_state(req_id=str(i + 1))
self.input_batch.add_request(request)
removed_req_indices = []
id_to_remove = ["2", "4"] # IDs to remove
for req_id in id_to_remove:
removed_index = self.input_batch.remove_request(req_id)
if removed_index is not None:
removed_req_indices.append(removed_index)
self.assertEqual(len(removed_req_indices), len(id_to_remove))
self.input_batch.condense(sorted(removed_req_indices, reverse=True))
# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(
reqs,
req_ids_retained,
input_batch.req_id_to_index,
device=torch.device(device))
# Check if the remaining requests are condensed correctly
indices = [
self.input_batch.req_id_to_index[req_id] for req_id in ["1", "3"]
]
self.assertTrue(all(idx < self.input_batch.num_reqs
for idx in indices))
def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
return (t1 is None
and t2 is None) or (t1 is not None and t2 is not None
and torch.allclose(t1, t2))
for i in range(self.input_batch.num_reqs):
self.assertIsNotNone(self.input_batch._req_ids[i])
for i in range(self.input_batch.num_reqs,
len(self.input_batch._req_ids)):
self.assertIsNone(self.input_batch._req_ids[i])
# Assert the actual and expected output.
assert torch.allclose(expected_sampling_metadata.temperature,
sampling_metadata.temperature)
assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
assert torch.allclose(
expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties,
)
assert torch.allclose(
expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties,
)
assert torch.allclose(
expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties,
)
assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
if sampling_metadata.allowed_token_ids_mask:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)
assert expected_sampling_metadata.bad_words_token_ids == \
sampling_metadata.bad_words_token_ids
for req_id in ["1", "3"]:
idx = self.input_batch.req_id_to_index[req_id]
tokens = self.input_batch.token_ids_cpu[idx]
self.assertTrue(
tokens.any(),
f"Tokens at index {idx} for req {req_id} should not be all zero"
)
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1), )])
def test_swap_states_in_input_batch(device: str, batch_size: int,
swap_list: list):
"""
Tests the logic for managing sampling metadata in the InputBatch.
This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
assigned_req_index = input_batch.add_request(req)
assert assigned_req_index == req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
reordered_reqs = reqs.copy()
for swap_pair in swap_list:
reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
input_batch.swap_states(swap_pair[0], swap_pair[1])
for req_index in range(batch_size):
req = reordered_reqs[req_index]
assigned_req_index = ref_input_batch.add_request(req)
assert assigned_req_index == req_index
input_batch.refresh_metadata()
ref_input_batch.refresh_metadata()
_compare_objs(input_batch, ref_input_batch)