main add ascend scheduler support multimodal (#2844)
### What this PR does / why we need it?
On main, AscendScheduler does not support Multimodels, becuse of lacking
of scheduled_encoder_inputs which is need on multimodels inference
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
vLLM version: main@93e28e6862669e3b5cf47cea9f782a65ec47e155
- vLLM version: v0.10.2rc2
- vLLM main:
15b8fef453
---------
Signed-off-by: fan2956 <zhoufan53@huawei.com>
Co-authored-by: zhoufan2956 <zhoufan2956@163.com>
This commit is contained in:
@@ -75,13 +75,6 @@ class TestAscendSchedulerConfig(TestBase):
|
|||||||
str(context.exception),
|
str(context.exception),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_not_implemented_multimodal(self):
|
|
||||||
with self.assertRaises(NotImplementedError) as context:
|
|
||||||
AscendSchedulerConfig.initialize_from_config(
|
|
||||||
SchedulerConfig(is_multimodal_model=True), {})
|
|
||||||
self.assertIn("currently AscendScheduler only supports LLM models",
|
|
||||||
str(context.exception))
|
|
||||||
|
|
||||||
def test_not_implemented_send_delta_data(self):
|
def test_not_implemented_send_delta_data(self):
|
||||||
with self.assertRaises(NotImplementedError) as context:
|
with self.assertRaises(NotImplementedError) as context:
|
||||||
AscendSchedulerConfig.initialize_from_config(
|
AscendSchedulerConfig.initialize_from_config(
|
||||||
@@ -118,6 +111,11 @@ class TestAscendSchedulerConfig(TestBase):
|
|||||||
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
|
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
|
||||||
self.assertEqual(ascend_config.encoder_cache_size, 8192)
|
self.assertEqual(ascend_config.encoder_cache_size, 8192)
|
||||||
|
|
||||||
|
def test_valid_config_with_multimodal(self):
|
||||||
|
config = AscendSchedulerConfig.initialize_from_config(
|
||||||
|
SchedulerConfig(is_multimodal_model=True), {})
|
||||||
|
self.assertTrue(config.is_multimodal_model)
|
||||||
|
|
||||||
def test_valid_config_with_chunked_prefill(self):
|
def test_valid_config_with_chunked_prefill(self):
|
||||||
ascend_config = AscendSchedulerConfig.initialize_from_config(
|
ascend_config = AscendSchedulerConfig.initialize_from_config(
|
||||||
self.basic_scheduler_config,
|
self.basic_scheduler_config,
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from unittest.mock import MagicMock, patch
|
|||||||
import torch
|
import torch
|
||||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||||
|
MultiModalKwargsItem, PlaceholderRange)
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import sha256
|
from vllm.utils import sha256
|
||||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||||
@@ -49,11 +50,23 @@ def create_requests(
|
|||||||
prompt_logprobs=prompt_logprobs)
|
prompt_logprobs=prompt_logprobs)
|
||||||
requests = []
|
requests = []
|
||||||
for i in range(num_requests):
|
for i in range(num_requests):
|
||||||
|
mm_features = []
|
||||||
|
if mm_positions is not None:
|
||||||
|
mm_position = mm_positions[i]
|
||||||
|
for j, position in enumerate(mm_position):
|
||||||
|
identifier = f"hash{i}_{j}"
|
||||||
|
mm_feature = MultiModalFeatureSpec(
|
||||||
|
data=MultiModalKwargsItem.dummy("dummy_m"),
|
||||||
|
mm_position=position,
|
||||||
|
identifier=identifier,
|
||||||
|
modality="image")
|
||||||
|
mm_features.append(mm_feature)
|
||||||
request = Request(request_id=f"{i}",
|
request = Request(request_id=f"{i}",
|
||||||
prompt_token_ids=[i] * num_tokens,
|
prompt_token_ids=[i] * num_tokens,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
eos_token_id=EOS_TOKEN_ID,
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
pooling_params=None,
|
pooling_params=None,
|
||||||
|
mm_features=mm_features if mm_features else None,
|
||||||
block_hasher=get_request_block_hasher(
|
block_hasher=get_request_block_hasher(
|
||||||
block_size, hash_fn))
|
block_size, hash_fn))
|
||||||
requests.append(request)
|
requests.append(request)
|
||||||
@@ -86,7 +99,7 @@ class TestAscendScheduler(TestBase):
|
|||||||
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
|
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
|
||||||
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
|
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
|
||||||
def create_scheduler(self, mock_compute_encoder_budget):
|
def create_scheduler(self, mock_compute_encoder_budget):
|
||||||
mock_compute_encoder_budget.return_value = [10, 20]
|
mock_compute_encoder_budget.return_value = [100, 100]
|
||||||
use_kv_connector = False
|
use_kv_connector = False
|
||||||
block_size = 16
|
block_size = 16
|
||||||
|
|
||||||
@@ -233,6 +246,39 @@ class TestAscendScheduler(TestBase):
|
|||||||
for i, request in enumerate(requests):
|
for i, request in enumerate(requests):
|
||||||
self.assertEqual(scheduler.running[i], request)
|
self.assertEqual(scheduler.running[i], request)
|
||||||
|
|
||||||
|
def test_schedule_multimodal_requests(self):
|
||||||
|
scheduler = self.create_scheduler()
|
||||||
|
scheduler.scheduler_config.chunked_prefill_enabled = False
|
||||||
|
mm_positions = [[PlaceholderRange(offset=i, length=10)]
|
||||||
|
for i in range(10)]
|
||||||
|
requests = create_requests(
|
||||||
|
num_requests=10,
|
||||||
|
mm_positions=mm_positions,
|
||||||
|
)
|
||||||
|
for request in requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
output = scheduler.schedule()
|
||||||
|
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
|
||||||
|
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
|
||||||
|
self.assertEqual(len(output.finished_req_ids), 0)
|
||||||
|
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||||
|
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
|
||||||
|
|
||||||
|
# Verify all requests are scheduled.
|
||||||
|
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||||
|
self.assertEqual(num_tokens,
|
||||||
|
len(requests[int(req_id)].prompt_token_ids))
|
||||||
|
self.assertEqual(len(output.scheduled_encoder_inputs), len(requests))
|
||||||
|
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
|
||||||
|
assert len(encoder_input) == 1
|
||||||
|
|
||||||
|
# Verify requests moved from waiting to running
|
||||||
|
self.assertEqual(len(scheduler.waiting), 0)
|
||||||
|
self.assertEqual(len(scheduler.running), len(requests))
|
||||||
|
for i, request in enumerate(requests):
|
||||||
|
self.assertEqual(scheduler.running[i], request)
|
||||||
|
|
||||||
def test_schedule_enable_prefix_caching(self):
|
def test_schedule_enable_prefix_caching(self):
|
||||||
'''Test scheduling.
|
'''Test scheduling.
|
||||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||||
|
|||||||
@@ -71,9 +71,6 @@ class AscendSchedulerConfig(SchedulerConfig):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
|
f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
|
||||||
)
|
)
|
||||||
if self.is_multimodal_model:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"currently AscendScheduler only supports LLM models.")
|
|
||||||
if self.send_delta_data:
|
if self.send_delta_data:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"currently AscendScheduler doesn't support send_delta_data.")
|
"currently AscendScheduler doesn't support send_delta_data.")
|
||||||
|
|||||||
@@ -72,6 +72,11 @@ class AscendScheduler(Scheduler):
|
|||||||
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
|
||||||
num_scheduled_tokens: dict[str, int] = {}
|
num_scheduled_tokens: dict[str, int] = {}
|
||||||
token_budget = self.max_num_scheduled_tokens
|
token_budget = self.max_num_scheduled_tokens
|
||||||
|
|
||||||
|
# Encoder-related.
|
||||||
|
scheduled_encoder_inputs: dict[str, list[int]] = {}
|
||||||
|
encoder_budget = self.max_num_encoder_input_tokens
|
||||||
|
|
||||||
# Spec decode-related.
|
# Spec decode-related.
|
||||||
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
|
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
|
||||||
|
|
||||||
@@ -155,6 +160,9 @@ class AscendScheduler(Scheduler):
|
|||||||
num_new_local_computed_tokens = 0
|
num_new_local_computed_tokens = 0
|
||||||
num_computed_tokens = request.num_computed_tokens
|
num_computed_tokens = request.num_computed_tokens
|
||||||
|
|
||||||
|
encoder_inputs_to_schedule = None
|
||||||
|
new_encoder_budget = encoder_budget
|
||||||
|
|
||||||
# P/D: loading remote KV, do not allocate for new work.
|
# P/D: loading remote KV, do not allocate for new work.
|
||||||
if load_kv_async:
|
if load_kv_async:
|
||||||
assert num_external_computed_tokens > 0
|
assert num_external_computed_tokens > 0
|
||||||
@@ -192,6 +200,16 @@ class AscendScheduler(Scheduler):
|
|||||||
assert num_new_tokens > 0
|
assert num_new_tokens > 0
|
||||||
blocks = new_computed_blocks.blocks[0]
|
blocks = new_computed_blocks.blocks[0]
|
||||||
|
|
||||||
|
# Schedule encoder inputs.
|
||||||
|
if request.has_encoder_inputs:
|
||||||
|
(encoder_inputs_to_schedule, num_new_tokens,
|
||||||
|
new_encoder_budget) = self._try_schedule_encoder_inputs(
|
||||||
|
request, num_computed_tokens, num_new_tokens,
|
||||||
|
encoder_budget)
|
||||||
|
if num_new_tokens == 0:
|
||||||
|
# The request cannot be scheduled.
|
||||||
|
break
|
||||||
|
|
||||||
watermark = getattr(self.scheduler_config, "watermark", 0.01)
|
watermark = getattr(self.scheduler_config, "watermark", 0.01)
|
||||||
if not self._check_watermark_for_prefill(request, num_new_tokens,
|
if not self._check_watermark_for_prefill(request, num_new_tokens,
|
||||||
blocks, watermark):
|
blocks, watermark):
|
||||||
@@ -256,6 +274,15 @@ class AscendScheduler(Scheduler):
|
|||||||
if request.num_cached_tokens < 0:
|
if request.num_cached_tokens < 0:
|
||||||
request.num_cached_tokens = num_computed_tokens
|
request.num_cached_tokens = num_computed_tokens
|
||||||
|
|
||||||
|
# Encoder-related.
|
||||||
|
if encoder_inputs_to_schedule:
|
||||||
|
scheduled_encoder_inputs[request.request_id] = (
|
||||||
|
encoder_inputs_to_schedule)
|
||||||
|
# Allocate the encoder cache.
|
||||||
|
for i in encoder_inputs_to_schedule:
|
||||||
|
self.encoder_cache_manager.allocate(request, i)
|
||||||
|
encoder_budget = new_encoder_budget
|
||||||
|
|
||||||
# Put back any skipped requests at the head of the waiting queue
|
# Put back any skipped requests at the head of the waiting queue
|
||||||
if skipped_waiting_requests:
|
if skipped_waiting_requests:
|
||||||
self.waiting.extendleft(skipped_waiting_requests)
|
self.waiting.extendleft(skipped_waiting_requests)
|
||||||
@@ -287,6 +314,16 @@ class AscendScheduler(Scheduler):
|
|||||||
num_new_tokens = min(
|
num_new_tokens = min(
|
||||||
num_new_tokens,
|
num_new_tokens,
|
||||||
self.max_model_len - request.num_computed_tokens)
|
self.max_model_len - request.num_computed_tokens)
|
||||||
|
|
||||||
|
# Schedule encoder inputs.
|
||||||
|
encoder_inputs_to_schedule = None
|
||||||
|
new_encoder_budget = encoder_budget
|
||||||
|
if request.has_encoder_inputs:
|
||||||
|
(encoder_inputs_to_schedule, num_new_tokens,
|
||||||
|
new_encoder_budget) = self._try_schedule_encoder_inputs(
|
||||||
|
request, request.num_computed_tokens, num_new_tokens,
|
||||||
|
encoder_budget)
|
||||||
|
|
||||||
# Check that adding the request still respects the max_loras
|
# Check that adding the request still respects the max_loras
|
||||||
# constraint.
|
# constraint.
|
||||||
if self.lora_config and request.lora_request and (
|
if self.lora_config and request.lora_request and (
|
||||||
@@ -358,6 +395,15 @@ class AscendScheduler(Scheduler):
|
|||||||
scheduled_spec_decode_tokens[request.request_id] = (
|
scheduled_spec_decode_tokens[request.request_id] = (
|
||||||
request.spec_token_ids)
|
request.spec_token_ids)
|
||||||
|
|
||||||
|
# Encoder-related.
|
||||||
|
if encoder_inputs_to_schedule:
|
||||||
|
scheduled_encoder_inputs[request.request_id] = (
|
||||||
|
encoder_inputs_to_schedule)
|
||||||
|
# Allocate the encoder cache.
|
||||||
|
for i in encoder_inputs_to_schedule:
|
||||||
|
self.encoder_cache_manager.allocate(request, i)
|
||||||
|
encoder_budget = new_encoder_budget
|
||||||
|
|
||||||
# Record scheduled LoRA requests.
|
# Record scheduled LoRA requests.
|
||||||
if self.lora_config and request.lora_request:
|
if self.lora_config and request.lora_request:
|
||||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||||
@@ -401,7 +447,7 @@ class AscendScheduler(Scheduler):
|
|||||||
num_scheduled_tokens=num_scheduled_tokens,
|
num_scheduled_tokens=num_scheduled_tokens,
|
||||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||||
# finished_req_ids is an existing state in the scheduler,
|
# finished_req_ids is an existing state in the scheduler,
|
||||||
# instead of being newly scheduled in this step.
|
# instead of being newly scheduled in this step.
|
||||||
|
|||||||
Reference in New Issue
Block a user