forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
369
vllm-v0.6.2/vllm/worker/openvino_model_runner.py
Normal file
369
vllm-v0.6.2/vllm/worker/openvino_model_runner.py
Normal file
@@ -0,0 +1,369 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, NamedTuple, Optional, Tuple
|
||||
|
||||
import openvino as ov
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.openvino import get_model
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalKwargs, MultiModalPlaceholderMap)
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.worker.model_runner_base import ModelRunnerBase
|
||||
|
||||
# Module-level logger, named after this module per vLLM convention.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ModelInput(NamedTuple):
    """Flattened, batched inputs for a single model step.

    Sequences are laid out in prefill -> decode order, matching the
    contract of ``OpenVINOModelRunner._prepare_model_input``.
    """
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: Optional[OpenVINOAttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
    multi_modal_kwargs: BatchedTensorInputs

    @classmethod
    def empty(cls, device):
        """Build a ModelInput that carries no tokens, placed on *device*."""
        return cls(
            input_tokens=torch.empty(0, device=device),
            input_positions=torch.empty(0, device=device),
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            multi_modal_kwargs={},
        )
|
||||
|
||||
|
||||
class OpenVINOModelRunner(ModelRunnerBase):
    """Model runner for the OpenVINO backend.

    Translates scheduler output (``SequenceGroupMetadata``) into the flat
    token/position tensors and paged-attention metadata that the OpenVINO
    attention backend expects, executes the model, and samples next tokens.
    """

    def __init__(
        self,
        ov_core: ov.Core,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        *args,
        **kwargs,
    ):
        # Keep a handle to the OpenVINO runtime core; it is forwarded to
        # the model loader in load_model().
        self.ov_core = ov_core
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        cache_config = self.cache_config
        model_config = self.model_config
        self.is_driver_worker = is_driver_worker

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.multi_modal_input_mapper = self.mm_registry \
            .create_input_mapper(self.model_config)

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model() is called.

    def load_model(self) -> None:
        """Load the model through the OpenVINO model loader."""
        self.model = get_model(model_config=self.model_config,
                               device_config=self.device_config,
                               kv_cache_dtype=self.kv_cache_dtype,
                               ov_core=self.ov_core)

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []

        seq_lens: List[int] = []
        past_lens: List[int] = []
        query_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        # Prefix-sum layouts consumed by the OpenVINO paged-attention
        # metadata: subsequence_begins delimits each sequence's query tokens
        # within the flattened batch; block_indices_begins delimits each
        # sequence's slice of block_indices.
        subsequence_begins: List[int] = []
        block_indices: List[int] = []
        block_indices_begins: List[int] = []

        # initialize beginning of prefix sums
        subsequence_begins.append(0)
        block_indices_begins.append(0)

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                # Chunked prefill combined with prefix caching is rejected
                # up front (see the TODO below about combining them).
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    computed_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    computed_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    computed_len + seq_group_metadata.token_chunk_size,
                )
                if is_prompt:
                    tokens = seq_data.get_token_ids()[computed_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                block_table = seq_group_metadata.block_tables[seq_id]
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    computed_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[computed_len:]
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # chunked prefill or decode
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if self.sliding_window is not None:
                            # chunked prefill doesn't support sliding window.
                            assert not self.scheduler_config.chunked_prefill_enabled  # noqa: E501
                            sliding_window_blocks = (self.sliding_window //
                                                     self.block_size)
                            block_table = block_table[-sliding_window_blocks:]
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # prompt phase w/o prefix_caching, chunked_prefill
                    pass

                block_indices.extend(block_table)
                block_indices_begins.append(block_indices_begins[-1] +
                                            len(block_table))

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle slinding window attn.
                if self.sliding_window is not None and not is_prompt:
                    seq_len = min(seq_len, self.sliding_window)
                    computed_len = seq_len - 1

                seq_lens.append(seq_len)

                query_len = seq_len - computed_len
                query_lens.append(query_len)

                input_tokens.extend(tokens)
                positions_range = range(computed_len, seq_len)
                input_positions.extend(list(positions_range))

                past_lens.append(computed_len)
                subsequence_begins.append(subsequence_begins[-1] + query_len)

                if is_prompt:
                    assert len(seq_ids) == 1
                else:
                    assert (
                        query_len == 1
                    ), "seq_len: {}, computed_len: {}, query_len: {}".format(
                        seq_len, computed_len, query_len)

                if seq_group_metadata.multi_modal_data:
                    # NOTE: mm_data only includes the subset of multi-modal
                    # items that intersect with the current prefill positions.
                    mm_data, placeholder_maps = MultiModalPlaceholderMap \
                        .from_seq_group(seq_group_metadata, positions_range)

                    if self.mm_registry.has_processor(self.model_config):
                        # The model has its own multi-modal processor; pass
                        # the data through unmapped.
                        mm_kwargs = mm_data
                    else:
                        mm_kwargs = self.multi_modal_input_mapper(
                            mm_data,
                            seq_group_metadata.mm_processor_kwargs,
                        )

                    multi_modal_kwargs_list.append(mm_kwargs)

                    for modality, placeholder_map in placeholder_maps.items():
                        multi_modal_placeholder_maps[modality].extend(
                            placeholder_map, )

        max_query_len = max(query_lens)
        assert max_query_len > 0, "query_lens: {}".format(query_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore

        past_lens_tensor = torch.tensor(past_lens,
                                        dtype=torch.int32,
                                        device=self.device)  # type: ignore
        subsequence_begins_tensor = torch.tensor(
            subsequence_begins, dtype=torch.int32,
            device=self.device)  # type: ignore
        block_indices_tensor = torch.tensor(block_indices,
                                            dtype=torch.int32,
                                            device=self.device)  # type: ignore
        block_indices_begins_tensor = torch.tensor(
            block_indices_begins, dtype=torch.int32,
            device=self.device)  # type: ignore

        max_context_len = max(seq_lens)
        max_context_len_tensor = torch.tensor(
            max_context_len, dtype=torch.int32,
            device=self.device)  # type: ignore

        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        attn_metadata = self.attn_backend.make_openvino_metadata(
            past_lens=past_lens_tensor,
            subsequence_begins=subsequence_begins_tensor,
            block_indices=block_indices_tensor,
            block_indices_begins=block_indices_begins_tensor,
            max_context_len=max_context_len_tensor,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
        )

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return ModelInput(
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs=multi_modal_kwargs,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
               SamplingMetadata, BatchedTensorInputs]:
        """Build model inputs plus sampling metadata for one step.

        Wraps _prepare_model_input and derives SamplingMetadata from the
        same seq_lens / query_lens.
        """
        # Prepare input tensors.
        (
            input_tokens,
            input_positions,
            attn_metadata,
            seq_lens,
            query_lens,
            multi_modal_kwargs,
        ) = self._prepare_model_input(seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            query_lens,
            self.device,
            pin_memory=False,
        )

        return (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass over the batch and sample next tokens.

        Returns the SamplerOutput produced by the model's sampler.
        """
        (
            input_tokens,
            input_positions,
            attn_metadata,
            sampling_metadata,
            multi_modal_kwargs,
        ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids":
            input_tokens,
            "positions":
            input_positions,
            "kv_caches":
            kv_caches,
            "attn_metadata":
            attn_metadata,
            **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
                                         device=self.device),
        }

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )
        return output

    def prepare_model_input(self, *args, **kwargs):
        # Part of the ModelRunnerBase interface that this runner does not
        # implement; input preparation goes through prepare_input_tensors().
        raise NotImplementedError

    def make_model_input_from_broadcasted_tensor_dict(self, *args, **kwargs):
        # Part of the ModelRunnerBase interface that this runner does not
        # implement; raises unconditionally.
        raise NotImplementedError
|
||||
Reference in New Issue
Block a user