update
This commit is contained in:
0
vllm/v1/worker/gpu/mm/__init__.py
Normal file
0
vllm/v1/worker/gpu/mm/__init__.py
Normal file
183
vllm/v1/worker/gpu/mm/encoder_runner.py
Normal file
183
vllm/v1/worker/gpu/mm/encoder_runner.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
|
||||
|
||||
|
||||
class EncoderRunner:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
max_num_tokens, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
"""
|
||||
Clear the multi-modal cache that was used during profiling,
|
||||
but no longer needed during inference.
|
||||
"""
|
||||
# TODO: Implement MM budget for encoder dummy run
|
||||
pass
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
"""Clear the GPU-side encoder cache storing vision embeddings.
|
||||
|
||||
This should be called when model weights are updated to ensure
|
||||
stale embeddings computed with old weights are not reused.
|
||||
"""
|
||||
self.encoder_cache.clear()
|
||||
|
||||
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
|
||||
self.req_id_to_mm_features[req_id] = mm_features
|
||||
|
||||
def free_encoder_cache(self, mm_hash: str) -> None:
|
||||
self.encoder_cache.pop(mm_hash, None)
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.req_id_to_mm_features.pop(req_id, None)
|
||||
|
||||
def prepare_mm_inputs(
|
||||
self, scheduled_encoder_inputs: dict[str, list[int]]
|
||||
) -> tuple[list[str], list[tuple[str, MultiModalKwargsItem]]]:
|
||||
mm_hashes: list[str] = []
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
for mm_input_id in encoder_input_ids:
|
||||
mm_feature = mm_features[mm_input_id]
|
||||
if mm_feature.data is None:
|
||||
continue
|
||||
mm_hashes.append(mm_feature.identifier)
|
||||
mm_kwargs.append((mm_feature.modality, mm_feature.data))
|
||||
|
||||
return mm_hashes, mm_kwargs
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_mm_encoder(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
mm_hashes: list[str],
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
|
||||
) -> list[torch.Tensor]:
|
||||
if not mm_hashes:
|
||||
return []
|
||||
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs, device=self.device, pin_memory=False
|
||||
):
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs, expected_num_items=num_items
|
||||
)
|
||||
encoder_outputs.extend(curr_group_outputs)
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
|
||||
return encoder_outputs
|
||||
|
||||
def gather_mm_embeddings(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
total_num_scheduled_tokens: int,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
query_start_loc: np.ndarray,
|
||||
prefill_lens: np.ndarray,
|
||||
computed_prefill_lens: np.ndarray,
|
||||
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
is_prefilling = (computed_prefill_lens < prefill_lens).tolist()
|
||||
all_decode = not any(is_prefilling)
|
||||
if all_decode:
|
||||
# All decode requests, so no need to gather any embeddings.
|
||||
return [], torch.zeros(
|
||||
total_num_scheduled_tokens, dtype=torch.bool, device=self.device
|
||||
)
|
||||
|
||||
query_start = computed_prefill_lens.tolist()
|
||||
query_end = (computed_prefill_lens + num_scheduled_tokens).tolist()
|
||||
|
||||
mm_embeds: list[torch.Tensor] = []
|
||||
is_mm_embed = torch.zeros(
|
||||
total_num_scheduled_tokens, dtype=torch.bool, device="cpu", pin_memory=True
|
||||
)
|
||||
for i, req_id in enumerate(req_ids):
|
||||
if not is_prefilling[i]:
|
||||
# OPTIMIZATION: Skip decode requests.
|
||||
continue
|
||||
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
for mm_feature in mm_features:
|
||||
pos_info = mm_feature.mm_position
|
||||
start_pos = pos_info.offset
|
||||
num_encoder_tokens = pos_info.length
|
||||
|
||||
if start_pos >= query_end[i]:
|
||||
# The encoder output is not needed in this step.
|
||||
break
|
||||
if start_pos + num_encoder_tokens <= query_start[i]:
|
||||
# The encoder output is already processed and stored
|
||||
# in the decoder's KV cache.
|
||||
continue
|
||||
|
||||
start_idx = max(query_start[i] - start_pos, 0)
|
||||
end_idx = min(query_end[i] - start_pos, num_encoder_tokens)
|
||||
assert start_idx < end_idx
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
|
||||
)
|
||||
# If there are no embeddings in the current range, we skip
|
||||
# gathering the embeddings.
|
||||
if curr_embeds_start == curr_embeds_end:
|
||||
continue
|
||||
|
||||
mm_hash = mm_feature.identifier
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
|
||||
else:
|
||||
mm_embeds_item = encoder_output[start_idx:end_idx]
|
||||
|
||||
req_start_pos = query_start_loc[i] + start_pos - query_start[i]
|
||||
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
||||
True if is_embed is None else is_embed
|
||||
)
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
|
||||
# Copy the is_mm_embed tensor to the GPU.
|
||||
is_mm_embed = is_mm_embed.to(device=self.device, non_blocking=True)
|
||||
return mm_embeds, is_mm_embed
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
input_ids: torch.Tensor,
|
||||
mm_embeds: list[torch.Tensor],
|
||||
is_mm_embed: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = model.embed_input_ids(
|
||||
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
|
||||
)
|
||||
# Copy to the pre-allocated buffer for CUDA graphs.
|
||||
self.inputs_embeds[: x.shape[0]] = x
|
||||
return self.inputs_embeds
|
||||
136
vllm/v1/worker/gpu/mm/mrope_utils.py
Normal file
136
vllm/v1/worker/gpu/mm/mrope_utils.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsMRoPE
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||
|
||||
|
||||
class MRopeState:
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
max_model_len: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.max_model_len = max_model_len
|
||||
self.device = device
|
||||
|
||||
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
|
||||
# wasting a lot of CPU memory.
|
||||
self.prefill_mrope_positions = StagedWriteTensor(
|
||||
(max_num_reqs * 3, max_model_len),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
uva_instead_of_gpu=True,
|
||||
)
|
||||
self.prefill_mrope_delta = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
|
||||
|
||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||
# position on purpose to make it non-contiguous so that it can work
|
||||
# with torch compile.
|
||||
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
|
||||
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
|
||||
# the modality of inputs. For text-only inputs, each dimension has
|
||||
# identical position IDs, making M-RoPE functionally equivalent to
|
||||
# 1D-RoPE.
|
||||
# See page 5 of https://arxiv.org/abs/2409.12191
|
||||
self.mrope_positions = torch.zeros(
|
||||
(3, max_num_tokens + 1), dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
def init_prefill_mrope_positions(
|
||||
self,
|
||||
req_idx: int,
|
||||
mrope_model: SupportsMRoPE,
|
||||
prefill_token_ids: list[int],
|
||||
mm_features: list,
|
||||
) -> None:
|
||||
prefill_mrope_positions, prefill_mrope_delta = (
|
||||
mrope_model.get_mrope_input_positions(prefill_token_ids, mm_features)
|
||||
)
|
||||
for i in range(3):
|
||||
pos = prefill_mrope_positions[i].tolist()
|
||||
self.prefill_mrope_positions.stage_write(3 * req_idx + i, 0, pos)
|
||||
self.prefill_mrope_delta.np[req_idx] = prefill_mrope_delta
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.prefill_mrope_positions.apply_write()
|
||||
self.prefill_mrope_delta.copy_to_uva()
|
||||
|
||||
def prepare_mrope_positions(
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
prefill_lens: torch.Tensor,
|
||||
num_computed_tokens: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_prepare_mrope_positions_kernel[(num_reqs,)](
|
||||
self.mrope_positions,
|
||||
self.mrope_positions.stride(0),
|
||||
self.prefill_mrope_positions.gpu,
|
||||
3 * self.max_model_len,
|
||||
self.max_model_len,
|
||||
self.prefill_mrope_delta.gpu,
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
prefill_lens,
|
||||
num_computed_tokens,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_mrope_positions_kernel(
|
||||
mrope_positions_ptr,
|
||||
mrope_positions_stride,
|
||||
prefill_mrope_positions_ptr,
|
||||
prefill_mrope_positions_stride0,
|
||||
prefill_mrope_positions_stride1,
|
||||
prefill_mrope_delta_ptr,
|
||||
idx_mapping_ptr,
|
||||
query_start_loc_ptr,
|
||||
prefill_lens_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
is_prefill = num_computed < prefill_len
|
||||
|
||||
query_start = tl.load(query_start_loc_ptr + batch_idx)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
mrope_delta = tl.load(prefill_mrope_delta_ptr + req_state_idx)
|
||||
for i in range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
orig_pos = num_computed + block
|
||||
|
||||
for j in tl.static_range(3):
|
||||
if is_prefill:
|
||||
# Read from pre-computed M-RoPE positions.
|
||||
pos = tl.load(
|
||||
prefill_mrope_positions_ptr
|
||||
+ req_state_idx * prefill_mrope_positions_stride0
|
||||
+ j * prefill_mrope_positions_stride1
|
||||
+ orig_pos,
|
||||
mask=mask,
|
||||
)
|
||||
else:
|
||||
# Apply M-RoPE delta.
|
||||
pos = orig_pos + mrope_delta
|
||||
tl.store(
|
||||
mrope_positions_ptr + j * mrope_positions_stride + query_start + block,
|
||||
pos,
|
||||
mask=mask,
|
||||
)
|
||||
Reference in New Issue
Block a user