main add ascend scheduler support multimodal (#2844)

### What this PR does / why we need it?
On main, AscendScheduler does not support Multimodels, becuse of lacking
of scheduled_encoder_inputs which is need on multimodels inference

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?
vLLM version: main@93e28e6862669e3b5cf47cea9f782a65ec47e155

- vLLM version: v0.10.2rc2
- vLLM main:
15b8fef453

---------

Signed-off-by: fan2956 <zhoufan53@huawei.com>
Co-authored-by: zhoufan2956 <zhoufan2956@163.com>
This commit is contained in:
fan2956
2025-09-14 09:38:51 +08:00
committed by GitHub
parent 0747a6e68c
commit c5a502fd2e
4 changed files with 100 additions and 13 deletions

View File

@@ -75,13 +75,6 @@ class TestAscendSchedulerConfig(TestBase):
str(context.exception), str(context.exception),
) )
def test_not_implemented_multimodal(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
SchedulerConfig(is_multimodal_model=True), {})
self.assertIn("currently AscendScheduler only supports LLM models",
str(context.exception))
def test_not_implemented_send_delta_data(self): def test_not_implemented_send_delta_data(self):
with self.assertRaises(NotImplementedError) as context: with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config( AscendSchedulerConfig.initialize_from_config(
@@ -118,6 +111,11 @@ class TestAscendSchedulerConfig(TestBase):
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192) self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
self.assertEqual(ascend_config.encoder_cache_size, 8192) self.assertEqual(ascend_config.encoder_cache_size, 8192)
def test_valid_config_with_multimodal(self):
config = AscendSchedulerConfig.initialize_from_config(
SchedulerConfig(is_multimodal_model=True), {})
self.assertTrue(config.is_multimodal_model)
def test_valid_config_with_chunked_prefill(self): def test_valid_config_with_chunked_prefill(self):
ascend_config = AscendSchedulerConfig.initialize_from_config( ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config, self.basic_scheduler_config,

View File

@@ -6,7 +6,8 @@ from unittest.mock import MagicMock, patch
import torch import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256 from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
@@ -49,11 +50,23 @@ def create_requests(
prompt_logprobs=prompt_logprobs) prompt_logprobs=prompt_logprobs)
requests = [] requests = []
for i in range(num_requests): for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
for j, position in enumerate(mm_position):
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
request = Request(request_id=f"{i}", request = Request(request_id=f"{i}",
prompt_token_ids=[i] * num_tokens, prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params, sampling_params=sampling_params,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
pooling_params=None, pooling_params=None,
mm_features=mm_features if mm_features else None,
block_hasher=get_request_block_hasher( block_hasher=get_request_block_hasher(
block_size, hash_fn)) block_size, hash_fn))
requests.append(request) requests.append(request)
@@ -86,7 +99,7 @@ class TestAscendScheduler(TestBase):
@patch("vllm.config.VllmConfig.__post_init__", MagicMock()) @patch("vllm.config.VllmConfig.__post_init__", MagicMock())
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget') @patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
def create_scheduler(self, mock_compute_encoder_budget): def create_scheduler(self, mock_compute_encoder_budget):
mock_compute_encoder_budget.return_value = [10, 20] mock_compute_encoder_budget.return_value = [100, 100]
use_kv_connector = False use_kv_connector = False
block_size = 16 block_size = 16
@@ -233,6 +246,39 @@ class TestAscendScheduler(TestBase):
for i, request in enumerate(requests): for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request) self.assertEqual(scheduler.running[i], request)
def test_schedule_multimodal_requests(self):
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
mm_positions = [[PlaceholderRange(offset=i, length=10)]
for i in range(10)]
requests = create_requests(
num_requests=10,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
self.assertEqual(len(output.scheduled_encoder_inputs), len(requests))
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
assert len(encoder_input) == 1
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_schedule_enable_prefix_caching(self): def test_schedule_enable_prefix_caching(self):
'''Test scheduling. '''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs

View File

@@ -71,9 +71,6 @@ class AscendSchedulerConfig(SchedulerConfig):
raise NotImplementedError( raise NotImplementedError(
f"currently AscendScheduler only supports fcfs policy, got {self.policy}" f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
) )
if self.is_multimodal_model:
raise NotImplementedError(
"currently AscendScheduler only supports LLM models.")
if self.send_delta_data: if self.send_delta_data:
raise NotImplementedError( raise NotImplementedError(
"currently AscendScheduler doesn't support send_delta_data.") "currently AscendScheduler doesn't support send_delta_data.")

View File

@@ -72,6 +72,11 @@ class AscendScheduler(Scheduler):
req_to_new_blocks: dict[str, KVCacheBlocks] = {} req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {} num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens token_budget = self.max_num_scheduled_tokens
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related. # Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {} scheduled_spec_decode_tokens: dict[str, list[int]] = {}
@@ -155,6 +160,9 @@ class AscendScheduler(Scheduler):
num_new_local_computed_tokens = 0 num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
# P/D: loading remote KV, do not allocate for new work. # P/D: loading remote KV, do not allocate for new work.
if load_kv_async: if load_kv_async:
assert num_external_computed_tokens > 0 assert num_external_computed_tokens > 0
@@ -192,6 +200,16 @@ class AscendScheduler(Scheduler):
assert num_new_tokens > 0 assert num_new_tokens > 0
blocks = new_computed_blocks.blocks[0] blocks = new_computed_blocks.blocks[0]
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
watermark = getattr(self.scheduler_config, "watermark", 0.01) watermark = getattr(self.scheduler_config, "watermark", 0.01)
if not self._check_watermark_for_prefill(request, num_new_tokens, if not self._check_watermark_for_prefill(request, num_new_tokens,
blocks, watermark): blocks, watermark):
@@ -256,6 +274,15 @@ class AscendScheduler(Scheduler):
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests) self.waiting.extendleft(skipped_waiting_requests)
@@ -287,6 +314,16 @@ class AscendScheduler(Scheduler):
num_new_tokens = min( num_new_tokens = min(
num_new_tokens, num_new_tokens,
self.max_model_len - request.num_computed_tokens) self.max_model_len - request.num_computed_tokens)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_budget)
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if self.lora_config and request.lora_request and ( if self.lora_config and request.lora_request and (
@@ -358,6 +395,15 @@ class AscendScheduler(Scheduler):
scheduled_spec_decode_tokens[request.request_id] = ( scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids) request.spec_token_ids)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Record scheduled LoRA requests. # Record scheduled LoRA requests.
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
@@ -401,7 +447,7 @@ class AscendScheduler(Scheduler):
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs={}, scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks, num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler, # finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step. # instead of being newly scheduled in this step.