# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class DefaultModelState(ModelState):
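    """Default per-model state for the GPU worker.

    Holds the optional multimodal encoder cache/runner and the optional
    M-RoPE position state, and builds the per-batch model and attention
    inputs from the scheduled batch.
    """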

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config
        self.model = model
        self.device = device

        self.supports_mm_inputs = encoder_cache is not None
        self.max_model_len = self.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
        self.dtype = self.model_config.dtype

        # Multimodal encoder support (only when an encoder cache is provided).
        if self.supports_mm_inputs:
            assert encoder_cache is not None
            self.encoder_cache = encoder_cache
            self.encoder_runner = EncoderRunner(
                model=self.model,
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                encoder_cache=encoder_cache,
                dtype=self.dtype,
                device=self.device,
            )

        # M-RoPE (multimodal rotary position embedding) state.
        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_state = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
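        """Register a newly scheduled request at ``req_index``.

        For M-RoPE models, pre-computes the request's prefill M-RoPE
        positions from its prompt tokens and multimodal features.
        """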
        if self.uses_mrope:
            # Pre-compute M-RoPE positions for prefill.
            assert new_req_data.prefill_token_ids is not None
            self.mrope_state.init_prefill_mrope_positions(
                req_index,
                self.model,  # type: ignore
                new_req_data.prefill_token_ids,
                mm_features=new_req_data.mm_features,
            )

    def apply_staged_writes(self) -> None:
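        """Apply staged M-RoPE state writes (no-op for non-M-RoPE models)."""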
        if self.uses_mrope:
            self.mrope_state.apply_staged_writes()

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
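        """Compute input embeddings for a batch with multimodal inputs.

        Runs the multimodal encoder for the scheduled encoder inputs,
        caches the outputs by mm_hash, and merges the cached multimodal
        embeddings into the token embeddings of the current batch.
        """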
        mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
            scheduled_encoder_inputs
        )
        if mm_kwargs:
            # Execute the multimodal encoder.
            encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
            # Cache the encoder outputs by mm_hash.
            self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))

        mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            req_states.prefill_len.np[input_batch.idx_mapping_np],
            req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
        )
        inputs_embeds = self.encoder_runner.get_inputs_embeds(
            input_batch.input_ids, mm_embeds, is_mm_embed
        )
        return inputs_embeds[: input_batch.num_tokens_after_padding]

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
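        """Return extra model inputs for the batch.

        For M-RoPE models this provides the multi-dimensional position
        tensor; for all other models the default 1D positions are used
        and an empty dict is returned.
        """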
        if not self.uses_mrope:
            # Common case (1D positions).
            return {}

        # Prepare M-RoPE positions.
        self.mrope_state.prepare_mrope_positions(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            req_states.prefill_len.gpu,
            req_states.num_computed_tokens.gpu,
        )
        mrope_positions = self.mrope_state.mrope_positions[
            :, : input_batch.num_tokens_after_padding
        ]
        return {"positions": mrope_positions}

    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
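        """Return dummy model inputs for ``num_reqs`` requests and
        ``num_tokens`` tokens, sliced from the preallocated embedding
        and position buffers.
        """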
        model_inputs: dict[str, torch.Tensor | None] = {}
        if self.supports_mm_inputs:
            inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
            model_inputs["inputs_embeds"] = inputs_embeds
        if self.uses_mrope:
            mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
            model_inputs["positions"] = mrope_positions
        return model_inputs

    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
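        """Build attention metadata for the current batch via ``build_attn_metadata``."""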
        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        attn_metadata = build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
        )
        return attn_metadata
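
# Illustrative call sequence (a sketch; the actual wiring of this class into
# the GPU model runner lives elsewhere in vllm.v1.worker.gpu and may differ):
#
#   state = DefaultModelState(vllm_config, model, encoder_cache, device)
#   state.add_request(req_index, new_req_data)   # per newly scheduled request
#   state.apply_staged_writes()
#   extra_inputs = state.prepare_inputs(input_batch, req_states)
#   attn_metadata = state.prepare_attn(
#       input_batch, block_tables, slot_mappings, attn_groups, kv_cache_config
#   )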