Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
        return self.model_runner_output


class AsyncPoolingOutput(AsyncModelRunnerOutput):
    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        pooler_output: torch.Tensor,
        is_valid: torch.Tensor | None,
        main_stream: torch.cuda.Stream,
        copy_stream: torch.cuda.Stream,
        copy_event: torch.cuda.Event,
    ):
        self.model_runner_output = model_runner_output
        self.pooler_output = pooler_output
        self.is_valid = is_valid
        self.copy_event = copy_event

        with stream(copy_stream, main_stream):
            copy_stream.wait_stream(main_stream)
            self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
            if self.is_valid is not None:
                self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
            else:
                self.is_valid_cpu = None
            self.copy_event.record(copy_stream)

    def get_output(self) -> ModelRunnerOutput:
        self.copy_event.synchronize()
        pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
        if self.is_valid_cpu is not None:
            is_valid_cpu = self.is_valid_cpu.tolist()
            for i, is_valid in enumerate(is_valid_cpu):
                if not is_valid:
                    pooler_output[i] = None
        self.model_runner_output.pooler_output = pooler_output
        return self.model_runner_output


def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
    return x.to("cpu", non_blocking=True).numpy()

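The class above follows the same copy-stream pattern as AsyncOutput: fork a side stream off the main stream, issue non-blocking device-to-host copies, record an event, and synchronize only when the result is consumed. A minimal stand-alone sketch of that pattern using plain torch APIs (the `stream(copy_stream, main_stream)` helper in the diff is vLLM-internal; the buffer names below are illustrative):

import torch

main_stream = torch.cuda.current_stream()
copy_stream = torch.cuda.Stream()
copy_event = torch.cuda.Event()

gpu_out = torch.randn(4, 8, device="cuda")           # stand-in for pooler_output
with torch.cuda.stream(copy_stream):
    copy_stream.wait_stream(main_stream)              # order the copy after the forward pass
    cpu_out = gpu_out.to("cpu", non_blocking=True)    # async D2H copy (truly async with pinned memory)
    copy_event.record(copy_stream)

# ... the main stream keeps enqueuing work here ...
copy_event.synchronize()                              # block only when the output is actually needed
per_request = list(cpu_out.unbind(dim=0))
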
@@ -119,6 +119,10 @@ class BlockTables:
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
        # NOTE(woosuk): The output may be used for CUDA graph capture.
        # Therefore, this method must return the persistent tensor
        # with the same memory address as that used during the model's forward pass,
        # rather than allocating a new tensor.
        return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)

    def compute_slot_mappings(
@@ -150,7 +154,14 @@ class BlockTables:
        return self.slot_mappings[:, :num_tokens]

    def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
        # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
        # This is because the padding logic is complex and kernels may access beyond
        # the requested range.
        self.slot_mappings.fill_(PAD_SLOT_ID)
        # NOTE(woosuk): The output may be used for CUDA graph capture.
        # Therefore, this method must return the persistent tensor
        # with the same memory address as that used during the model's forward pass,
        # rather than allocating a new tensor.
        return self.slot_mappings[:, :num_tokens]

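The NOTE above is the key constraint behind CUDA-graph capture: a captured graph bakes in the device addresses of its input tensors, so the dummy helpers must hand back the same persistent buffers used during normal forward passes, and later steps feed new data by writing into those buffers and replaying. A small, self-contained illustration of that behavior (plain torch, not vLLM code):

import torch

buf = torch.zeros(8, device="cuda")       # persistent input buffer (fixed address)
out = torch.zeros(8, device="cuda")

g = torch.cuda.CUDAGraph()
# A warmup run on a side stream is recommended before capture; omitted for brevity.
with torch.cuda.graph(g):
    out.copy_(buf * 2)                    # the graph records buf's device address

buf.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
g.replay()                                # re-runs the captured work on the updated buffer
print(out)                                # 0., 2., 4., ..., 14.
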
@@ -3,7 +3,6 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
@@ -15,13 +14,12 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
@@ -29,13 +27,11 @@ class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
uses_mrope: bool,
|
||||
use_aux_hidden_state_outputs: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.uses_mrope = uses_mrope
|
||||
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
|
||||
self.device = device
|
||||
|
||||
@@ -88,9 +84,8 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
model: nn.Module,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -113,24 +108,23 @@ class CudaGraphManager:
|
||||
)
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
if self.uses_mrope:
|
||||
assert mrope_positions is not None
|
||||
positions = mrope_positions[:, :num_tokens]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:num_tokens]
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": input_buffers.input_ids[:num_tokens],
|
||||
"positions": input_buffers.positions[:num_tokens],
|
||||
# NOTE: Values returned by `prepare_dummy_inputs` will override the
|
||||
# default values above.
|
||||
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
|
||||
}
|
||||
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
model_state,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=(
|
||||
self.uniform_decode_query_len if uniform_decode else 0
|
||||
),
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
@@ -143,11 +137,7 @@ class CudaGraphManager:
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model_output = model(**model_inputs)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
@@ -164,9 +154,7 @@ class CudaGraphManager:
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
model_inputs=model_inputs,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
@@ -178,9 +166,7 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
model_inputs: dict[str, torch.Tensor | None],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
@@ -206,11 +192,8 @@ class CudaGraphManager:
|
||||
),
|
||||
torch.cuda.graph(graph, self.pool),
|
||||
):
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model_output = model(**model_inputs)
|
||||
|
||||
# Join offloader's copy stream after forward to avoid unjoined
|
||||
# stream error. The last layer's start_prefetch forks copy_stream,
|
||||
# but wait_prefetch only happens in the next forward pass.
|
||||
@@ -235,9 +218,7 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
model_inputs: dict[str, torch.Tensor | None],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
@@ -256,19 +237,14 @@ class CudaGraphManager:
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model(**model_inputs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -278,9 +254,8 @@ class CudaGraphManager:
|
||||
device=self.device,
|
||||
capture_fn=self.capture_graph,
|
||||
model=model,
|
||||
model_state=model_state,
|
||||
input_buffers=input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=block_tables,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
@@ -412,51 +387,36 @@ def capture_graphs(
|
||||
def prepare_inputs_to_capture(
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
max_model_len: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
uniform_decode_query_len: int = 0,
|
||||
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
|
||||
if uniform_decode_query_len > 0:
|
||||
num_tokens_per_req = uniform_decode_query_len
|
||||
else:
|
||||
num_tokens_per_req = num_tokens // num_reqs
|
||||
|
||||
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
|
||||
query_start_loc_np[-1] = num_tokens
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
|
||||
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
|
||||
|
||||
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
|
||||
# rather than max_model_len.
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
|
||||
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
|
||||
|
||||
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
|
||||
input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
|
||||
input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
|
||||
slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, kv_cache_config
|
||||
)
|
||||
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=num_tokens_per_req,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
max_seq_len=max_model_len,
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
|
||||
# HACK(woosuk): Special handling for DCP.
|
||||
if block_tables.cp_size > 1:
|
||||
prepare_dcp_local_seq_lens(
|
||||
input_buffers.dcp_local_seq_lens,
|
||||
input_batch.seq_lens,
|
||||
num_reqs,
|
||||
block_tables.cp_size,
|
||||
block_tables.cp_rank,
|
||||
block_tables.cp_interleave,
|
||||
)
|
||||
input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
attn_metadata = model_state.prepare_attn(
|
||||
input_batch,
|
||||
input_block_tables,
|
||||
slot_mappings,
|
||||
attn_groups,
|
||||
kv_cache_config,
|
||||
)
|
||||
return attn_metadata, slot_mappings_by_layer
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -60,20 +59,13 @@ class InputBatch:
|
||||
query_start_loc_np: np.ndarray
|
||||
# [num_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
# [num_reqs]
|
||||
dcp_local_seq_lens: torch.Tensor | None
|
||||
|
||||
# [num_tokens_after_padding]
|
||||
input_ids: torch.Tensor
|
||||
# [num_tokens_after_padding]
|
||||
positions: torch.Tensor
|
||||
# [3, num_tokens_after_padding]
|
||||
mrope_positions: torch.Tensor | None
|
||||
# [num_tokens_after_padding, hidden_size]
|
||||
inputs_embeds: torch.Tensor | None
|
||||
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
# layer_name -> slot_mapping
|
||||
slot_mappings: dict[str, torch.Tensor]
|
||||
|
||||
# [total_num_logits]
|
||||
logits_indices: torch.Tensor
|
||||
@@ -90,14 +82,16 @@ class InputBatch:
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
input_buffers: InputBuffers,
|
||||
device: torch.device,
|
||||
) -> "InputBatch":
|
||||
assert 0 < num_reqs <= num_tokens
|
||||
device = input_buffers.device
|
||||
|
||||
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
|
||||
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
|
||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||
expanded_idx_mapping = idx_mapping
|
||||
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
|
||||
|
||||
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
|
||||
num_scheduled_tokens[-1] += num_tokens % num_reqs
|
||||
assert int(num_scheduled_tokens.sum()) == num_tokens
|
||||
@@ -123,7 +117,6 @@ class InputBatch:
|
||||
input_ids = input_buffers.input_ids[:num_tokens].zero_()
|
||||
positions = input_buffers.positions[:num_tokens].zero_()
|
||||
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
|
||||
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
|
||||
@@ -141,12 +134,9 @@ class InputBatch:
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
dcp_local_seq_lens=None,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=None,
|
||||
inputs_embeds=None,
|
||||
attn_metadata=None, # type: ignore
|
||||
slot_mappings=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
@@ -507,6 +497,38 @@ def post_update(
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _post_update_pool_kernel(
|
||||
idx_mapping_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
query_start_loc_ptr,
|
||||
):
|
||||
batch_id = tl.program_id(0)
|
||||
query_start = tl.load(query_start_loc_ptr + batch_id)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_id + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_id)
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)
|
||||
|
||||
|
||||
def post_update_pool(
|
||||
# [num_reqs]
|
||||
idx_mapping: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
num_computed_tokens: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_post_update_pool_kernel[(num_reqs,)](
|
||||
idx_mapping,
|
||||
num_computed_tokens,
|
||||
query_start_loc,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _expand_idx_mapping_kernel(
|
||||
idx_mapping_ptr,
|
||||
|
||||
vllm/v1/worker/gpu/mm/encoder_cache.py (new file, 40 lines added)
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.multimodal.inputs import MultiModalFeatureSpec


class EncoderCache:
    def __init__(self):
        # req_id -> MM features
        self.mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
        # MM hash -> encoder outputs
        self.encoder_outputs: dict[str, torch.Tensor] = {}

    def add_request(
        self, req_id: str, mm_features: list[MultiModalFeatureSpec]
    ) -> None:
        self.mm_features[req_id] = mm_features

    def remove_request(self, req_id: str) -> None:
        self.mm_features.pop(req_id, None)

    def reset_mm_cache(self) -> None:
        """
        Clear the multi-modal cache that was used during profiling,
        but no longer needed during inference.
        """
        # TODO: Implement MM budget for encoder dummy run
        pass

    def reset_encoder_cache(self) -> None:
        """Clear the GPU-side encoder cache storing vision embeddings.

        This should be called when model weights are updated to ensure
        stale embeddings computed with old weights are not reused.
        """
        self.encoder_outputs.clear()

    def free_encoder_cache(self, mm_hash: str) -> None:
        self.encoder_outputs.pop(mm_hash, None)
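A brief usage sketch of the new cache (illustrative only; the hash string and tensor below are stand-ins): per-request feature lists and mm_hash-keyed encoder outputs are tracked separately, so one cached embedding can serve several requests and is dropped only when the scheduler frees it.

import torch

cache = EncoderCache()
cache.add_request("req-0", mm_features=[])               # features normally come from the scheduler
cache.encoder_outputs["mm-hash-0"] = torch.zeros(4, 8)   # cached vision embedding for that hash
cache.free_encoder_cache("mm-hash-0")                    # scheduler signals the embedding is no longer needed
cache.remove_request("req-0")                            # request finished; forget its feature list
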
@@ -4,54 +4,32 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItem
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
||||
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
|
||||
|
||||
|
||||
class EncoderRunner:
|
||||
def __init__(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
encoder_cache: EncoderCache,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.model = model
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.encoder_cache = encoder_cache
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
max_num_tokens, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
"""
|
||||
Clear the multi-modal cache that was used during profiling,
|
||||
but no longer needed during inference.
|
||||
"""
|
||||
# TODO: Implement MM budget for encoder dummy run
|
||||
pass
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
"""Clear the GPU-side encoder cache storing vision embeddings.
|
||||
|
||||
This should be called when model weights are updated to ensure
|
||||
stale embeddings computed with old weights are not reused.
|
||||
"""
|
||||
self.encoder_cache.clear()
|
||||
|
||||
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
|
||||
self.req_id_to_mm_features[req_id] = mm_features
|
||||
|
||||
def free_encoder_cache(self, mm_hash: str) -> None:
|
||||
self.encoder_cache.pop(mm_hash, None)
|
||||
|
||||
def remove_request(self, req_id: str) -> None:
|
||||
self.req_id_to_mm_features.pop(req_id, None)
|
||||
|
||||
def prepare_mm_inputs(
|
||||
self, scheduled_encoder_inputs: dict[str, list[int]]
|
||||
@@ -59,7 +37,7 @@ class EncoderRunner:
|
||||
mm_hashes: list[str] = []
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]] = []
|
||||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
mm_features = self.encoder_cache.mm_features[req_id]
|
||||
for mm_input_id in encoder_input_ids:
|
||||
mm_feature = mm_features[mm_input_id]
|
||||
if mm_feature.data is None:
|
||||
@@ -72,25 +50,17 @@ class EncoderRunner:
|
||||
@torch.inference_mode()
|
||||
def execute_mm_encoder(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
mm_hashes: list[str],
|
||||
mm_kwargs: list[tuple[str, MultiModalKwargsItem]],
|
||||
) -> list[torch.Tensor]:
|
||||
if not mm_hashes:
|
||||
return []
|
||||
|
||||
encoder_outputs: list[torch.Tensor] = []
|
||||
for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs, device=self.device, pin_memory=False
|
||||
):
|
||||
curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)
|
||||
curr_group_outputs = self.model.embed_multimodal(**mm_kwargs_group)
|
||||
sanity_check_mm_encoder_outputs(
|
||||
curr_group_outputs, expected_num_items=num_items
|
||||
)
|
||||
encoder_outputs.extend(curr_group_outputs)
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
self.encoder_cache.update(zip(mm_hashes, encoder_outputs))
|
||||
return encoder_outputs
|
||||
|
||||
def gather_mm_embeddings(
|
||||
@@ -122,7 +92,7 @@ class EncoderRunner:
|
||||
# OPTIMIZATION: Skip decode requests.
|
||||
continue
|
||||
|
||||
mm_features = self.req_id_to_mm_features[req_id]
|
||||
mm_features = self.encoder_cache.mm_features[req_id]
|
||||
for mm_feature in mm_features:
|
||||
pos_info = mm_feature.mm_position
|
||||
start_pos = pos_info.offset
|
||||
@@ -148,7 +118,7 @@ class EncoderRunner:
|
||||
continue
|
||||
|
||||
mm_hash = mm_feature.identifier
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
encoder_output = self.encoder_cache.encoder_outputs.get(mm_hash, None)
|
||||
assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
@@ -170,12 +140,11 @@ class EncoderRunner:
|
||||
@torch.inference_mode()
|
||||
def get_inputs_embeds(
|
||||
self,
|
||||
model: SupportsMultiModal,
|
||||
input_ids: torch.Tensor,
|
||||
mm_embeds: list[torch.Tensor],
|
||||
is_mm_embed: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = model.embed_input_ids(
|
||||
x = self.model.embed_input_ids(
|
||||
input_ids, multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
|
||||
)
|
||||
# Copy to the pre-allocated buffer for CUDA graphs.
|
||||
|
||||
@@ -38,15 +38,16 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
get_kv_cache_spec,
|
||||
init_attn_backend,
|
||||
@@ -56,10 +57,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.dp_utils import (
|
||||
get_cudagraph_and_dp_padding,
|
||||
make_num_tokens_across_dp,
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
@@ -67,6 +65,7 @@ from vllm.v1.worker.gpu.input_batch import (
|
||||
expand_idx_mapping,
|
||||
get_num_sampled_and_rejected,
|
||||
post_update,
|
||||
post_update_pool,
|
||||
prepare_pos_seq_lens,
|
||||
prepare_prefill_inputs,
|
||||
)
|
||||
@@ -76,8 +75,9 @@ from vllm.v1.worker.gpu.kv_connector import (
|
||||
get_kv_connector,
|
||||
)
|
||||
from vllm.v1.worker.gpu.lora_utils import LoraState
|
||||
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
|
||||
from vllm.v1.worker.gpu.model_states import init_model_state
|
||||
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
|
||||
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
|
||||
@@ -120,34 +120,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype
|
||||
]
|
||||
self.is_pooling_model = False
|
||||
|
||||
self.vocab_size = self.model_config.get_vocab_size()
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
|
||||
|
||||
# Multimodal
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||
self.model_config
|
||||
)
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner = EncoderRunner(
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
hidden_size=self.inputs_embeds_size,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.uses_mrope = self.model_config.uses_mrope
|
||||
if self.uses_mrope:
|
||||
self.mrope_states = MRopeState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
max_model_len=self.max_model_len,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.use_async_scheduling = self.scheduler_config.async_scheduling
|
||||
self.output_copy_stream = torch.cuda.Stream(self.device)
|
||||
@@ -169,6 +146,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
|
||||
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
|
||||
|
||||
# Multimodal
|
||||
self.mm_registry = MULTIMODAL_REGISTRY
|
||||
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
|
||||
self.model_config
|
||||
)
|
||||
self.encoder_cache = None
|
||||
if self.supports_mm_inputs and self.is_first_pp_rank:
|
||||
self.encoder_cache = EncoderCache()
|
||||
|
||||
self.speculator = None
|
||||
self.num_speculative_steps = 0
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
@@ -212,7 +198,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
self.vllm_config,
|
||||
self.uses_mrope,
|
||||
self.use_aux_hidden_state_outputs,
|
||||
self.device,
|
||||
)
|
||||
@@ -227,13 +212,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# KV Connector if configured.
|
||||
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
|
||||
|
||||
# Pooling models.
|
||||
self.is_pooling_model = self.model_config.runner_type == "pooling"
|
||||
self.pooling_runner: PoolingRunner | None = None
|
||||
|
||||
# For transferring state from execute_model to subsequent sample_tokens call.
|
||||
self.execute_model_state: tuple | None = None
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
self.req_states.max_model_len = max_model_len
|
||||
|
||||
@staticmethod
|
||||
def get_supported_tasks() -> tuple[str]:
|
||||
return ("generate",)
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
tasks: list[SupportedTask] = []
|
||||
if self.model_config.runner_type == "generate":
|
||||
tasks.append("generate")
|
||||
if self.pooling_runner is not None:
|
||||
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
|
||||
return tuple(tasks)
|
||||
|
||||
def load_model(self, *args, **kwargs) -> None:
|
||||
time_before_load = time.perf_counter()
|
||||
@@ -266,7 +262,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
prepare_communication_buffer_for_model(self.model)
|
||||
if self.speculator is not None:
|
||||
prepare_communication_buffer_for_model(self.speculator)
|
||||
prepare_communication_buffer_for_model(self.speculator.model)
|
||||
|
||||
# Initialize the components that require the model.
|
||||
self.model_state = init_model_state(
|
||||
self.vllm_config, self.model, self.encoder_cache, self.device
|
||||
)
|
||||
if self.is_pooling_model:
|
||||
self.pooling_runner = PoolingRunner(self.model)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
@@ -305,6 +308,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.speculator is not None:
|
||||
# HACK(woosuk)
|
||||
self.speculator.set_attn(
|
||||
self.model_state,
|
||||
self.kv_cache_config,
|
||||
self.attn_groups,
|
||||
self.block_tables,
|
||||
@@ -320,39 +324,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
|
||||
|
||||
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
|
||||
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
|
||||
slot_mappings = self.block_tables.get_dummy_slot_mappings(
|
||||
input_batch.num_tokens
|
||||
)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=input_batch.num_reqs,
|
||||
num_tokens=input_batch.num_tokens,
|
||||
query_start_loc_gpu=input_batch.query_start_loc,
|
||||
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
|
||||
max_query_len=input_batch.num_scheduled_tokens.max().item(),
|
||||
seq_lens=input_batch.seq_lens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
|
||||
)
|
||||
input_batch.attn_metadata = attn_metadata
|
||||
input_batch.slot_mappings = slot_mappings_by_layer
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
|
||||
self,
|
||||
num_tokens: int,
|
||||
*args,
|
||||
skip_attn: bool = True,
|
||||
uniform_decode: bool = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
# Create a dummy scheduler output.
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
if uniform_decode:
|
||||
# Align tokens to uniform_decode_query_len for cudagraph
|
||||
# compatibility across DP ranks.
|
||||
query_len = self.cudagraph_manager.uniform_decode_query_len
|
||||
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
|
||||
num_tokens = num_reqs * query_len
|
||||
num_tokens_per_request = [query_len] * num_reqs
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
assert sum(num_tokens_per_request) == num_tokens
|
||||
num_scheduled_tokens = {
|
||||
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
|
||||
@@ -387,7 +379,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return None, None
|
||||
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, _, input_batch, _ = self.execute_model_state
|
||||
(
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
) = self.execute_model_state
|
||||
self.execute_model_state = None
|
||||
|
||||
# dummy run the eagle speculator's propose to ensure DP/EP sync.
|
||||
if self.speculator is not None:
|
||||
self.speculator.propose(
|
||||
input_batch=input_batch,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings_by_layer,
|
||||
last_hidden_states=hidden_states,
|
||||
aux_hidden_states=aux_hidden_states,
|
||||
num_sampled=torch.ones(
|
||||
input_batch.num_reqs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
num_rejected=torch.zeros(
|
||||
input_batch.num_reqs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
last_sampled=self.req_states.last_sampled_tokens,
|
||||
next_prefill_tokens=self.req_states.next_prefill_tokens,
|
||||
temperature=self.sampler.sampling_states.temperature.gpu,
|
||||
seeds=self.sampler.sampling_states.seeds.gpu,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
dummy_run=True,
|
||||
skip_attn_for_dummy_run=skip_attn,
|
||||
)
|
||||
|
||||
assert hidden_states is not None # Last PP rank always has hidden_states
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
@@ -416,39 +442,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
expanded_local_pos,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
|
||||
assert self.pooling_runner is not None
|
||||
self.pooling_runner.dummy_pooler_run(hidden_states)
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
hidden_states, sample_hidden_states = self._dummy_run(
|
||||
self.max_num_tokens, skip_attn=True
|
||||
)
|
||||
|
||||
# Only run sampler on last PP rank (non-last ranks return None).
|
||||
# Only run sampler/pooler on last PP rank (non-last ranks return None).
|
||||
if self.is_last_pp_rank:
|
||||
assert sample_hidden_states is not None
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
|
||||
if self.speculator is not None:
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(
|
||||
self.parallel_config.data_parallel_size, self.max_num_tokens
|
||||
)
|
||||
self.speculator.run_model(
|
||||
self.max_num_tokens,
|
||||
attn_metadata=None,
|
||||
slot_mappings=None,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
if self.pooling_runner is None:
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
else:
|
||||
self._dummy_pooler_run(hidden_states)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sample_hidden_states
|
||||
gc.collect()
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.reset_mm_cache()
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.reset_mm_cache()
|
||||
|
||||
def reset_encoder_cache(self) -> None:
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.reset_encoder_cache()
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.reset_encoder_cache()
|
||||
|
||||
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
|
||||
# SP is not supported yet.
|
||||
@@ -477,17 +500,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
with self.maybe_setup_dummy_loras(self.lora_config):
|
||||
mrope_positions = None
|
||||
if self.uses_mrope:
|
||||
mrope_positions = self.mrope_states.mrope_positions
|
||||
inputs_embeds = None
|
||||
if self.supports_mm_inputs:
|
||||
inputs_embeds = self.encoder_runner.inputs_embeds
|
||||
self.cudagraph_manager.capture(
|
||||
model=self.model,
|
||||
model_state=self.model_state,
|
||||
input_buffers=self.input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=self.block_tables,
|
||||
attn_groups=self.attn_groups,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
@@ -522,21 +538,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
finished_req_ids = finished_req_ids.union(preempted_req_ids)
|
||||
for req_id in finished_req_ids:
|
||||
self.req_states.remove_request(req_id)
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.remove_request(req_id)
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.remove_request(req_id)
|
||||
self.prompt_logprobs_worker.remove_request(req_id)
|
||||
self.lora_state.remove_request(req_id)
|
||||
|
||||
def free_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
if self.supports_mm_inputs:
|
||||
if self.encoder_cache is not None:
|
||||
for mm_hash in scheduler_output.free_encoder_mm_hashes:
|
||||
self.encoder_runner.free_encoder_cache(mm_hash)
|
||||
self.encoder_cache.free_encoder_cache(mm_hash)
|
||||
|
||||
def add_requests(self, scheduler_output: SchedulerOutput) -> None:
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
assert new_req_data.prompt_token_ids is not None
|
||||
assert new_req_data.prefill_token_ids is not None
|
||||
assert new_req_data.sampling_params is not None
|
||||
req_id = new_req_data.req_id
|
||||
prompt_len = len(new_req_data.prompt_token_ids)
|
||||
self.req_states.add_request(
|
||||
@@ -547,34 +562,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
req_index = self.req_states.req_id_to_index[req_id]
|
||||
|
||||
if self.supports_mm_inputs:
|
||||
self.encoder_runner.add_request(req_id, new_req_data.mm_features)
|
||||
|
||||
# Pre-compute M-RoPE positions for prefill.
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.init_prefill_mrope_positions(
|
||||
req_index,
|
||||
self.model, # type: ignore
|
||||
new_req_data.prefill_token_ids,
|
||||
mm_features=new_req_data.mm_features,
|
||||
)
|
||||
if self.encoder_cache is not None:
|
||||
self.encoder_cache.add_request(req_id, new_req_data.mm_features)
|
||||
|
||||
self.model_state.add_request(req_index, new_req_data)
|
||||
self.block_tables.append_block_ids(
|
||||
req_index, new_req_data.block_ids, overwrite=True
|
||||
)
|
||||
self.sampler.add_request(
|
||||
req_index, prompt_len, new_req_data.sampling_params
|
||||
)
|
||||
self.prompt_logprobs_worker.add_request(
|
||||
req_id, req_index, new_req_data.sampling_params
|
||||
)
|
||||
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
|
||||
|
||||
if new_req_data.sampling_params is not None:
|
||||
self.sampler.add_request(
|
||||
req_index, prompt_len, new_req_data.sampling_params
|
||||
)
|
||||
self.prompt_logprobs_worker.add_request(
|
||||
req_id, req_index, new_req_data.sampling_params
|
||||
)
|
||||
|
||||
if scheduler_output.scheduled_new_reqs:
|
||||
self.req_states.apply_staged_writes()
|
||||
self.sampler.apply_staged_writes()
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.apply_staged_writes()
|
||||
self.model_state.apply_staged_writes()
|
||||
|
||||
def update_requests(self, scheduler_output: SchedulerOutput) -> None:
|
||||
# Add new blocks for the existing requests.
|
||||
@@ -637,9 +645,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
idx_mapping, total_num_logits, cu_num_logits, max_expand_len
|
||||
)
|
||||
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(idx_mapping)
|
||||
|
||||
# Get query_start_loc.
|
||||
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
|
||||
query_start_loc_np[0] = 0
|
||||
@@ -648,11 +653,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
|
||||
query_start_loc_np[num_reqs + 1 :] = num_tokens
|
||||
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
|
||||
|
||||
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
max_query_len = num_scheduled_tokens.max().item()
|
||||
|
||||
# Get prefill tokens if any.
|
||||
if self.req_states.any_prefills(idx_mapping_np):
|
||||
@@ -676,6 +678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
seq_lens = self.input_buffers.seq_lens[:num_reqs]
|
||||
|
||||
dcp_local_seq_lens = None
|
||||
if self.use_dcp:
|
||||
# Prepare dcp local seq_lens.
|
||||
prepare_dcp_local_seq_lens(
|
||||
@@ -686,16 +689,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.dcp_rank,
|
||||
self.cp_interleave,
|
||||
)
|
||||
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
# Prepare M-RoPE positions.
|
||||
if self.uses_mrope:
|
||||
self.mrope_states.prepare_mrope_positions(
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.num_computed_tokens.gpu,
|
||||
)
|
||||
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
# Some input token ids are directly read from the last sampled tokens
|
||||
# and draft tokens. Also, get the logits indices to sample tokens from.
|
||||
@@ -711,39 +705,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
total_num_logits,
|
||||
)
|
||||
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping,
|
||||
query_start_loc,
|
||||
self.input_buffers.positions[:num_tokens],
|
||||
)
|
||||
# Layer name -> slot mapping.
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
|
||||
# Layer name -> attention metadata.
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=max_query_len,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
)
|
||||
|
||||
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
||||
mrope_positions = None
|
||||
if self.uses_mrope:
|
||||
mrope_positions = self.mrope_states.mrope_positions
|
||||
mrope_positions = mrope_positions[:, :num_tokens_after_padding]
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
@@ -758,37 +719,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=None,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings_by_layer,
|
||||
dcp_local_seq_lens=dcp_local_seq_lens,
|
||||
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
|
||||
positions=self.input_buffers.positions[:num_tokens_after_padding],
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_mm_embeddings(
|
||||
self,
|
||||
scheduled_encoder_inputs: dict[str, list[int]],
|
||||
input_batch: InputBatch,
|
||||
) -> tuple[list[torch.Tensor], torch.Tensor]:
|
||||
mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
|
||||
scheduled_encoder_inputs
|
||||
def prepare_attn(
|
||||
self, input_batch: InputBatch
|
||||
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
input_batch.idx_mapping,
|
||||
input_batch.query_start_loc,
|
||||
input_batch.positions,
|
||||
)
|
||||
self.encoder_runner.execute_mm_encoder(self.model, mm_hashes, mm_kwargs)
|
||||
mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
|
||||
input_batch.req_ids,
|
||||
input_batch.num_tokens,
|
||||
input_batch.num_scheduled_tokens,
|
||||
input_batch.query_start_loc_np,
|
||||
self.req_states.prefill_len.np[input_batch.idx_mapping_np],
|
||||
self.req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
|
||||
return block_tables, slot_mappings
|
||||
|
||||
def prepare_dummy_attn(
|
||||
self, input_batch: InputBatch
|
||||
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
|
||||
slot_mappings = self.block_tables.get_dummy_slot_mappings(
|
||||
input_batch.num_tokens
|
||||
)
|
||||
return mm_embeds, is_mm_embed
|
||||
return block_tables, slot_mappings
|
||||
|
||||
def sample(
|
||||
self,
|
||||
@@ -926,6 +886,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_batch = self.prepare_inputs(
|
||||
scheduler_output, num_tokens_after_padding
|
||||
)
|
||||
block_tables, slot_mappings = self.prepare_attn(input_batch)
|
||||
|
||||
if self.lora_config:
|
||||
# Activate LoRA adapters.
|
||||
lora_inputs = self.lora_state.make_lora_inputs(
|
||||
@@ -934,35 +896,61 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_batch.num_scheduled_tokens,
|
||||
)
|
||||
self._set_active_loras(*lora_inputs)
|
||||
|
||||
# Only first PP rank prepares multimodal embeddings.
|
||||
if self.supports_mm_inputs and self.is_first_pp_rank:
|
||||
mm_embeds, is_mm_embed = self.get_mm_embeddings(
|
||||
scheduler_output.scheduled_encoder_inputs, input_batch
|
||||
)
|
||||
inputs_embeds = self.encoder_runner.get_inputs_embeds(
|
||||
self.model, input_batch.input_ids, mm_embeds, is_mm_embed
|
||||
)
|
||||
input_batch.inputs_embeds = inputs_embeds[
|
||||
: input_batch.num_tokens_after_padding
|
||||
]
|
||||
else:
|
||||
# No actual tokens to run. A dummy run for DP or memory profiling.
|
||||
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens_after_padding,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
num_reqs, num_tokens_after_padding, self.input_buffers
|
||||
)
|
||||
if self.uses_mrope:
|
||||
input_batch.mrope_positions = self.mrope_states.mrope_positions[
|
||||
:, :num_tokens_after_padding
|
||||
]
|
||||
if not skip_attn_for_dummy_run:
|
||||
self.prepare_dummy_attn_metadata(input_batch)
|
||||
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
|
||||
else:
|
||||
block_tables = None
|
||||
slot_mappings = None
|
||||
# FIXME(woosuk): Fix warmup for LoRA.
|
||||
|
||||
attn_metadata = None
|
||||
slot_mappings_by_layer = None
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
assert slot_mappings is not None
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
assert block_tables is not None
|
||||
attn_metadata = self.model_state.prepare_attn(
|
||||
input_batch,
|
||||
block_tables,
|
||||
slot_mappings,
|
||||
self.attn_groups,
|
||||
self.kv_cache_config,
|
||||
)
|
||||
|
||||
inputs_embeds = None
|
||||
if self.supports_mm_inputs and self.is_first_pp_rank:
|
||||
# Run MM encoder (if needed) and get multimodal embeddings.
|
||||
# Only first PP rank prepares multimodal embeddings.
|
||||
# NOTE(woosuk): We must call get_mm_embeddings even during dummy runs
|
||||
# to obtain inputs_embeds, because the compiled model expects this input.
|
||||
inputs_embeds = self.model_state.get_mm_embeddings(
|
||||
scheduler_output.scheduled_encoder_inputs,
|
||||
input_batch,
|
||||
self.req_states,
|
||||
)
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": input_batch.input_ids,
|
||||
"positions": input_batch.positions,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
# NOTE: Values returned by `prepare_inputs` will override the default
|
||||
# values above.
|
||||
**self.model_state.prepare_inputs(input_batch, self.req_states),
|
||||
}
|
||||
if not self.is_first_pp_rank:
|
||||
# Update for non-first PP ranks.
|
||||
model_inputs["input_ids"] = None
|
||||
model_inputs["inputs_embeds"] = None
|
||||
model_inputs["intermediate_tensors"] = intermediate_tensors
|
||||
|
||||
# Run model.
|
||||
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
# Use explicit cudagraph replay for FULL mode.
|
||||
@@ -979,41 +967,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
aux_hidden_states = None
|
||||
else:
|
||||
# For piecewise and eager mode, just call model().
|
||||
positions = input_batch.positions
|
||||
if self.uses_mrope:
|
||||
assert input_batch.mrope_positions is not None
|
||||
positions = input_batch.mrope_positions
|
||||
|
||||
if self.is_first_pp_rank:
|
||||
input_ids = input_batch.input_ids
|
||||
inputs_embeds = input_batch.inputs_embeds
|
||||
assert intermediate_tensors is None
|
||||
else:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
assert intermediate_tensors is not None
|
||||
|
||||
batch_descriptor = BatchDescriptor(
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
has_lora=self.lora_config is not None,
|
||||
)
|
||||
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=input_batch.slot_mappings,
|
||||
slot_mapping=slot_mappings_by_layer,
|
||||
):
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
model_output = self.model(**model_inputs)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
@@ -1021,33 +990,44 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
aux_hidden_states = None
|
||||
|
||||
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
|
||||
self.execute_model_state = (
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
)
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
# Non-last PP rank: return IntermediateTensors for sending.
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
hidden_states.kv_connector_output = kv_connector_output
|
||||
self.execute_model_state = (None, None, input_batch, kv_connector_output)
|
||||
return hidden_states
|
||||
|
||||
# Last rank (or no PP): hidden_states is a tensor for sampling.
|
||||
assert isinstance(hidden_states, torch.Tensor)
|
||||
self.execute_model_state = (
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
input_batch,
|
||||
kv_connector_output,
|
||||
) # type: ignore
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample_tokens(
|
||||
self, grammar_output: GrammarOutput | None
|
||||
) -> AsyncOutput | ModelRunnerOutput | None:
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
|
||||
self.execute_model_state
|
||||
)
|
||||
self.execute_model_state = None # type: ignore
|
||||
if self.execute_model_state is None:
|
||||
# The prior execute_model call must have failed.
|
||||
return None
|
||||
(
|
||||
input_batch,
|
||||
model_inputs,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
kv_connector_output,
|
||||
num_tokens_across_dp,
|
||||
) = self.execute_model_state
|
||||
self.execute_model_state = None
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
# Non-last PP rank: hidden_states is None because this rank produced
|
||||
@@ -1109,6 +1089,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.speculator is not None:
|
||||
draft_tokens = self.speculator.propose(
|
||||
input_batch,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
num_sampled,
|
||||
@@ -1117,6 +1099,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.req_states.next_prefill_tokens,
|
||||
self.sampler.sampling_states.temperature.gpu,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
|
||||
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
|
||||
@@ -1127,3 +1110,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        return self.draft_tokens_handler.get_draft_tokens()

    @torch.inference_mode()
    def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
        if self.execute_model_state is None:
            # The prior execute_model call must have failed.
            return None

        input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
            self.execute_model_state
        )
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            self.postprocess_pool(input_batch)
            return None

        assert self.pooling_runner is not None
        pooler_output, is_valid = self.pooling_runner.pool(
            hidden_states, input_batch, self.req_states
        )
        self.postprocess_pool(input_batch)

        # Build the model runner output.
        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncPoolingOutput(
            model_runner_output=model_runner_output,
            pooler_output=pooler_output,
            is_valid=is_valid,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )
        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()

    def postprocess_pool(self, input_batch: InputBatch) -> None:
        # Update the number of computed tokens.
        post_update_pool(
            input_batch.idx_mapping,
            self.req_states.num_computed_tokens.gpu,
            input_batch.query_start_loc,
        )

        # Update the number of computed prefill tokens.
        idx_mapping_np = input_batch.idx_mapping_np
        computed_prefill = self.req_states.num_computed_prefill_tokens
        computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
        np.minimum(
            computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
        )

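The CPU side of postprocess_pool reduces to an indexed add plus a clamp on numpy arrays. A tiny stand-alone rendering of that arithmetic (array names are stand-ins for the req_states fields):

import numpy as np

computed_prefill = np.array([5, 120], dtype=np.int32)   # prefill tokens already processed per request
num_scheduled = np.array([8, 50], dtype=np.int32)       # tokens processed this step
prefill_len = np.array([10, 200], dtype=np.int32)       # full prompt lengths

idx = np.array([0, 1], dtype=np.int32)                  # batch slot -> request-state index
computed_prefill[idx] += num_scheduled
np.minimum(computed_prefill, prefill_len, out=computed_prefill)
print(computed_prefill)                                  # [ 10 170]
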
vllm/v1/worker/gpu/model_states/__init__.py (new file, 18 lines added)
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache


def init_model_state(
    vllm_config: VllmConfig,
    model: nn.Module,
    encoder_cache: EncoderCache | None,
    device: torch.device,
):
    from vllm.v1.worker.gpu.model_states.default import DefaultModelState

    return DefaultModelState(vllm_config, model, encoder_cache, device)

vllm/v1/worker/gpu/model_states/default.py (new file, 161 lines)
@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class DefaultModelState(ModelState):
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config
        self.model = model
        self.device = device

        self.supports_mm_inputs = encoder_cache is not None
        self.max_model_len = self.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()
        self.dtype = self.model_config.dtype

        if self.supports_mm_inputs:
            assert encoder_cache is not None
            self.encoder_cache = encoder_cache
            self.encoder_runner = EncoderRunner(
                model=self.model,
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                encoder_cache=encoder_cache,
                dtype=self.dtype,
                device=self.device,
            )

        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_state = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        if self.uses_mrope:
            # Pre-compute M-RoPE positions for prefill.
            assert new_req_data.prefill_token_ids is not None
            self.mrope_state.init_prefill_mrope_positions(
                req_index,
                self.model,  # type: ignore
                new_req_data.prefill_token_ids,
                mm_features=new_req_data.mm_features,
            )

    def apply_staged_writes(self) -> None:
        if self.uses_mrope:
            self.mrope_state.apply_staged_writes()

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        mm_hashes, mm_kwargs = self.encoder_runner.prepare_mm_inputs(
            scheduled_encoder_inputs
        )
        if mm_kwargs:
            # Execute the multimodal encoder.
            encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
            # Cache the encoder outputs by mm_hash
            self.encoder_cache.encoder_outputs.update(zip(mm_hashes, encoder_outputs))

        mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            req_states.prefill_len.np[input_batch.idx_mapping_np],
            req_states.num_computed_prefill_tokens[input_batch.idx_mapping_np],
        )
        inputs_embeds = self.encoder_runner.get_inputs_embeds(
            input_batch.input_ids, mm_embeds, is_mm_embed
        )
        return inputs_embeds[: input_batch.num_tokens_after_padding]

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        if not self.uses_mrope:
            # Common case (1D positions).
            return {}

        # Prepare M-RoPE positions.
        self.mrope_state.prepare_mrope_positions(
            input_batch.idx_mapping,
            input_batch.query_start_loc,
            req_states.prefill_len.gpu,
            req_states.num_computed_tokens.gpu,
        )
        mrope_positions = self.mrope_state.mrope_positions[
            :, : input_batch.num_tokens_after_padding
        ]
        return {"positions": mrope_positions}

    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
        model_inputs = {}
        if self.supports_mm_inputs:
            inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
            model_inputs["inputs_embeds"] = inputs_embeds
        if self.uses_mrope:
            mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
            model_inputs["positions"] = mrope_positions
        return model_inputs

    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        attn_metadata = build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
        )
        return attn_metadata

vllm/v1/worker/gpu/model_states/interface.py (new file, 67 lines)
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class ModelState(ABC):
    @abstractmethod
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply_staged_writes(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        raise NotImplementedError

    @abstractmethod
    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        raise NotImplementedError

    @abstractmethod
    def prepare_dummy_inputs(
        self, num_reqs: int, num_tokens: int
    ) -> dict[str, torch.Tensor | None]:
        raise NotImplementedError

    @abstractmethod
    def prepare_attn(
        self,
        input_batch: InputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        raise NotImplementedError

vllm/v1/worker/gpu/pool/__init__.py (new file, empty)

vllm/v1/worker/gpu/pool/pooling_runner.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
from vllm.tasks import PoolingTask
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.states import RequestState


# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. How to support other pooling tasks and models
# is to be determined.
class PoolingRunner:
    def __init__(self, model: nn.Module):
        self.model = cast(VllmModelForPooling, model)

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        if not is_pooling_model(self.model):
            return []
        assert "embed" in self.model.pooler.get_supported_tasks()
        return ["embed"]

    def pool(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # TODO(woosuk): Support different types of pooling tasks.
        last_hidden_states = hidden_states[input_batch.logits_indices]
        # TODO(woosuk): Make normalization optional.
        last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)

        prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
        is_valid = input_batch.seq_lens == prompt_len
        return last_hidden_states, is_valid

    def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
        F.normalize(hidden_states, p=2, dim=-1)
        return
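
For reference (not part of this commit): the "LAST" pooling described in the note above reduces to taking the hidden state at each request's final scheduled token and L2-normalizing it, and an embedding is only treated as valid once the whole prompt has been processed (seq_lens == prompt_len). A minimal PyTorch sketch with made-up shapes; logits_indices marks the last token of each request:

import torch
import torch.nn.functional as F

hidden_states = torch.randn(10, 16)       # [num_tokens, hidden_size]
logits_indices = torch.tensor([3, 7, 9])  # last token index of each of 3 requests

embeddings = F.normalize(hidden_states[logits_indices], p=2, dim=-1)
print(embeddings.shape)                   # torch.Size([3, 16]); rows have unit L2 norm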
@@ -72,7 +72,7 @@ class BadWordsState:
    def apply_bad_words(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
@@ -84,7 +84,7 @@ class BadWordsState:

        apply_bad_words(
            logits,
            idx_mapping,
            expanded_idx_mapping,
            self.bad_word_token_ids.gpu,
            self.bad_word_offsets.gpu,
            self.num_bad_words.gpu,
@@ -114,17 +114,17 @@ def _bad_words_kernel(
    input_ids_ptr,
    expanded_local_pos_ptr,
):
    logit_idx = tl.program_id(0)
    token_idx = tl.program_id(0)
    bw_idx = tl.program_id(1)

    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
    num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)

    if bw_idx >= num_bad_words:
        return

    pos = tl.load(expanded_local_pos_ptr + logit_idx)
    cur_req_first_pos = logit_idx - pos
    pos = tl.load(expanded_local_pos_ptr + token_idx)
    cur_req_first_pos = token_idx - pos

    prompt_len = tl.load(prompt_len_ptr + req_state_idx)
    total_len = tl.load(total_len_ptr + req_state_idx)
@@ -159,7 +159,7 @@ def _bad_words_kernel(
        match = match & (expected == actual)

    if match:
        tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
        tl.store(logits_ptr + token_idx * logits_stride + last_token, -float("inf"))


def apply_bad_words(
@@ -175,8 +175,8 @@ def apply_bad_words(
    expanded_local_pos: torch.Tensor,
    max_num_bad_words: int,
) -> None:
    total_num_tokens = logits.shape[0]
    _bad_words_kernel[(total_num_tokens, max_num_bad_words)](
    num_tokens = logits.shape[0]
    _bad_words_kernel[(num_tokens, max_num_bad_words)](
        logits,
        logits.stride(0),
        expanded_idx_mapping,

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _temperature_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    temperature_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
    temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
    if temperature == 0.0 or temperature == 1.0:
        # Early return to avoid loading logits.
@@ -25,24 +25,24 @@ def _temperature_kernel(
        block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size

        logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
        logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
        logits = logits.to(tl.float32)
        logits = logits / temperature
        tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
        tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)


def apply_temperature(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
) -> None:
    num_reqs, vocab_size = logits.shape
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _temperature_kernel[(num_reqs, num_blocks)](
    _temperature_kernel[(num_tokens, num_blocks)](
        logits,
        logits.stride(0),
        idx_mapping,
        expanded_idx_mapping,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
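
For reference (not part of this commit): the kernel above now launches one program per logit row (token) instead of one per request, and looks up each row's sampling state through expanded_idx_mapping. A plain PyTorch sketch of the same per-token temperature scaling; the tensor values and shapes are assumptions for illustration:

import torch

logits = torch.randn(5, 32000)                        # [num_tokens, vocab_size]
temperature = torch.tensor([0.7, 1.0, 0.0])           # [max_num_reqs] per-request state
expanded_idx_mapping = torch.tensor([0, 0, 1, 2, 2])  # request-state index per token

temp = temperature[expanded_idx_mapping]              # [num_tokens]
# temperature 0.0 (greedy) and 1.0 are no-ops, matching the kernel's early return.
scale = torch.where((temp == 0.0) | (temp == 1.0), torch.ones_like(temp), temp)
logits = logits / scale.unsqueeze(-1)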
@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
    local_max_stride,
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(
        logits_ptr + batch_idx * logits_stride + block,
        logits_ptr + token_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
    if temp != 0.0:
        # Calculate the seed for gumbel noise.
        seed = tl.load(seeds_ptr + req_state_idx)
        pos = tl.load(pos_ptr + batch_idx)
        pos = tl.load(pos_ptr + token_idx)
        gumbel_seed = tl.randint(seed, pos)

        # Generate gumbel noise in FP32.
@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(

    value, idx = tl.max(logits, axis=0, return_indices=True)
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
    tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)


def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    idx_mapping: torch.Tensor,  # [max_num_reqs]
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    expanded_idx_mapping: torch.Tensor,  # [num_tokens]
    temperature: torch.Tensor,  # [max_num_reqs]
    seed: torch.Tensor,  # [max_num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_tokens]
    apply_temperature: bool,
) -> torch.Tensor:
    num_reqs, vocab_size = logits.shape
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    local_argmax = torch.empty(
        num_reqs,
        num_tokens,
        num_blocks,
        dtype=torch.int64,
        device=logits.device,
    )
    local_max = torch.empty(
        num_reqs,
        num_tokens,
        num_blocks,
        dtype=torch.float32,
        device=logits.device,
    )
    _gumbel_sample_kernel[(num_reqs, num_blocks)](
    _gumbel_sample_kernel[(num_tokens, num_blocks)](
        local_argmax,
        local_argmax.stride(0),
        local_max,
        local_max.stride(0),
        logits,
        logits.stride(0),
        idx_mapping,
        expanded_idx_mapping,
        seed,
        pos,
        temperature,

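
For reference (not part of this commit): gumbel_sample relies on the Gumbel-max trick, where adding Gumbel(0, 1) noise to temperature-scaled logits and taking the argmax draws a token from the softmax distribution. A minimal PyTorch sketch; the per-request seeding and the blockwise argmax reduction done by the kernel are omitted, and the shapes are assumptions:

import torch

logits = torch.randn(4, 32000)                  # [num_tokens, vocab_size]
temperature = torch.tensor([0.8, 1.0, 0.0, 1.2]).unsqueeze(-1)

u = torch.rand_like(logits).clamp_min(1e-10)
gumbel = -torch.log(-torch.log(u))              # Gumbel(0, 1) noise
noisy = logits / temperature.clamp_min(1e-10) + gumbel
greedy = (temperature == 0.0).squeeze(-1)       # temperature 0.0 means plain argmax
sampled = torch.where(greedy, logits.argmax(dim=-1), noisy.argmax(dim=-1))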
@@ -121,7 +121,7 @@ class LogitBiasState:
    def apply_logit_bias(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
    ) -> None:
@@ -131,7 +131,7 @@ class LogitBiasState:

        apply_logit_bias(
            logits,
            idx_mapping,
            expanded_idx_mapping,
            pos,
            self.num_allowed_token_ids.gpu,
            self.allowed_token_ids.gpu,
@@ -149,7 +149,7 @@ def _bias_kernel(
    logits_ptr,
    logits_stride,
    vocab_size,
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    # Allowed token IDs.
    num_allowed_token_ids_ptr,
    allowed_token_ids_ptr,
@@ -169,8 +169,8 @@ def _bias_kernel(
    BLOCK_SIZE: tl.constexpr,
    LOGITS_BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

    block = tl.arange(0, BLOCK_SIZE)

@@ -186,21 +186,21 @@ def _bias_kernel(
            mask=mask,
        )
        logits = tl.load(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask
            logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
        )

        # Set logits to -inf for all tokens.
        for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
            offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
            tl.store(
                logits_ptr + batch_idx * logits_stride + offset,
                logits_ptr + token_idx * logits_stride + offset,
                -float("inf"),
                mask=offset < vocab_size,
            )

        # Restore logits for allowed token IDs.
        tl.store(
            logits_ptr + batch_idx * logits_stride + allowed_token_ids,
            logits_ptr + token_idx * logits_stride + allowed_token_ids,
            logits,
            mask=mask,
        )
@@ -214,13 +214,13 @@ def _bias_kernel(
            mask=mask,
        )
        bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
        logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask)
        logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
        logits += bias
        tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask)
        tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)

    # Apply min tokens.
    num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
    pos = tl.load(pos_ptr + batch_idx)
    pos = tl.load(pos_ptr + token_idx)
    min_len = tl.load(min_lens_ptr + req_state_idx)
    if num_stop_token_ids > 0 and pos < min_len:
        mask = block < num_stop_token_ids
@@ -229,7 +229,7 @@ def _bias_kernel(
            mask=mask,
        )
        tl.store(
            logits_ptr + batch_idx * logits_stride + stop_token_ids,
            logits_ptr + token_idx * logits_stride + stop_token_ids,
            -float("inf"),
            mask=mask,
        )
@@ -237,7 +237,7 @@ def _bias_kernel(

def apply_logit_bias(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
@@ -248,7 +248,7 @@ def apply_logit_bias(
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    num_reqs, vocab_size = logits.shape
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = triton.next_power_of_2(
        max(
            allowed_token_ids.shape[-1],
@@ -257,11 +257,11 @@ def apply_logit_bias(
        )
    )
    LOGITS_BLOCK_SIZE = 8192
    _bias_kernel[(num_reqs,)](
    _bias_kernel[(num_tokens,)](
        logits,
        logits.stride(0),
        vocab_size,
        idx_mapping,
        expanded_idx_mapping,
        num_allowed_token_ids,
        allowed_token_ids,
        allowed_token_ids.stride(0),

@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _min_p_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    min_p_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
    min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
    if min_p == 0.0:
        return
@@ -25,7 +25,9 @@ def _min_p_kernel(
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
            logits_ptr + token_idx * logits_stride + block,
            mask=mask,
            other=float("-inf"),
        )
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore
@@ -35,21 +37,23 @@ def _min_p_kernel(
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < vocab_size
        logits = tl.load(
            logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf")
            logits_ptr + token_idx * logits_stride + block,
            mask=mask,
            other=float("-inf"),
        )
        logits = tl.where(logits < threshold, float("-inf"), logits)
        tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)
        tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)


def apply_min_p(
    logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor
    logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None:
    num_reqs, vocab_size = logits.shape
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    _min_p_kernel[(num_reqs,)](
    _min_p_kernel[(num_tokens,)](
        logits,
        logits.stride(0),
        idx_mapping,
        expanded_idx_mapping,
        min_p,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,

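
For reference (not part of this commit): min-p keeps only tokens whose probability is at least min_p times that of the most likely token, which in logit space means keeping logits above max_logit + log(min_p). The exact threshold computation sits in lines elided from this hunk, so the sketch below only follows the standard min-p definition, with assumed shapes:

import torch

logits = torch.randn(3, 32000)                 # [num_tokens, vocab_size]
min_p = torch.tensor([0.05, 0.0, 0.1]).unsqueeze(-1)

max_logit = logits.max(dim=-1, keepdim=True).values
threshold = max_logit + torch.log(min_p.clamp_min(1e-10))
keep = (logits >= threshold) | (min_p == 0.0)  # min_p == 0.0 disables filtering
filtered = logits.masked_fill(~keep, float("-inf"))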
@@ -82,7 +82,7 @@ class PenaltiesState:
    def apply_penalties(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
@@ -94,7 +94,7 @@ class PenaltiesState:

        apply_penalties(
            logits,
            idx_mapping,
            expanded_idx_mapping,
            input_ids,
            expanded_local_pos,
            self.repetition_penalty.gpu,
@@ -110,7 +110,7 @@ class PenaltiesState:
def _penalties_kernel(
    logits_ptr,
    logits_stride,
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    token_ids_ptr,
    expanded_local_pos_ptr,
    repetition_penalty_ptr,
@@ -125,7 +125,7 @@ def _penalties_kernel(
    MAX_SPEC_LEN: tl.constexpr,
):
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + token_idx)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
    rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
    pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
@@ -191,7 +191,7 @@ def _penalties_kernel(

def apply_penalties(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    token_ids: torch.Tensor,
    expanded_local_pos: torch.Tensor,
    repetition_penalty: torch.Tensor,
@@ -207,7 +207,7 @@ def apply_penalties(
    _penalties_kernel[(num_tokens, num_blocks)](
        logits,
        logits.stride(0),
        idx_mapping,
        expanded_idx_mapping,
        token_ids,
        expanded_local_pos,
        repetition_penalty,
@@ -225,7 +225,7 @@ def apply_penalties(

@triton.jit
def _bincount_kernel(
    idx_mapping_ptr,
    expanded_idx_mapping_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    prompt_len_ptr,
@@ -236,9 +236,9 @@ def _bincount_kernel(
    output_bin_counts_stride,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    token_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
    if block_idx * BLOCK_SIZE >= prefill_len:
@@ -276,7 +276,7 @@ def _bincount_kernel(


def bincount(
    idx_mapping: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    all_token_ids: torch.Tensor,
    prompt_len: torch.Tensor,
    prefill_len: torch.Tensor,
@@ -284,13 +284,13 @@ def bincount(
    output_bin_counts: torch.Tensor,
    max_prefill_len: int,
) -> None:
    prompt_bin_mask[idx_mapping] = 0
    output_bin_counts[idx_mapping] = 0
    num_reqs = idx_mapping.shape[0]
    prompt_bin_mask[expanded_idx_mapping] = 0
    output_bin_counts[expanded_idx_mapping] = 0
    num_tokens = expanded_idx_mapping.shape[0]
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
    _bincount_kernel[(num_reqs, num_blocks)](
        idx_mapping,
    _bincount_kernel[(num_tokens, num_blocks)](
        expanded_idx_mapping,
        all_token_ids,
        all_token_ids.stride(0),
        prompt_len,

@@ -56,7 +56,7 @@ class Sampler:
    def __call__(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        cu_num_logits_np: np.ndarray,
        pos: torch.Tensor,
@@ -68,7 +68,7 @@ class Sampler:
        num_nans = get_num_nans(logits) if self.compute_nans else None
        sampled, processed_logits = self.sample(
            logits,
            idx_mapping,
            expanded_idx_mapping,
            idx_mapping_np,
            pos,
            input_ids,
@@ -101,7 +101,7 @@ class Sampler:
    def sample(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
        input_ids: torch.Tensor,
@@ -111,12 +111,14 @@ class Sampler:
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

        # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
        self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
        self.logit_bias_state.apply_logit_bias(
            logits, expanded_idx_mapping, idx_mapping_np, pos
        )

        # Apply penalties in place.
        self.penalties_state.apply_penalties(
            logits,
            idx_mapping,
            expanded_idx_mapping,
            idx_mapping_np,
            input_ids,
            expanded_local_pos,
@@ -126,27 +128,29 @@ class Sampler:
        # Apply bad words masking in place.
        self.bad_words_state.apply_bad_words(
            logits,
            idx_mapping,
            expanded_idx_mapping,
            idx_mapping_np,
            input_ids,
            expanded_local_pos,
        )

        # Apply temperature in place.
        self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
        self.sampling_states.apply_temperature(
            logits, expanded_idx_mapping, idx_mapping_np
        )

        # Apply min_p in place.
        self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
        self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)

        # Apply top_k and/or top_p. This might or might not return a new tensor.
        logits = self.sampling_states.apply_top_k_top_p(
            logits, idx_mapping, idx_mapping_np
            logits, expanded_idx_mapping, idx_mapping_np
        )

        # Sample the next token.
        sampled = gumbel_sample(
            logits,
            idx_mapping,
            expanded_idx_mapping,
            self.sampling_states.temperature.gpu,
            self.sampling_states.seeds.gpu,
            pos,

@@ -64,7 +64,7 @@ class SamplingStates:
    def apply_temperature(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
    ) -> None:
        temp_np = self.temperature.np[idx_mapping_np]
@@ -72,23 +72,23 @@ class SamplingStates:
            # No request requires temperature. Skip the kernel launch.
            return

        apply_temperature(logits, idx_mapping, self.temperature.gpu)
        apply_temperature(logits, expanded_idx_mapping, self.temperature.gpu)

    def apply_min_p(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
    ) -> None:
        if np.all(self.min_p.np[idx_mapping_np] == 0.0):
            # No request uses min_p. Skip the kernel launch.
            return
        apply_min_p(logits, idx_mapping, self.min_p.gpu)
        apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)

    def apply_top_k_top_p(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
    ) -> torch.Tensor:
        do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
@@ -96,8 +96,8 @@ class SamplingStates:
        if not (do_top_k or do_top_p):
            return logits

        top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
        top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
        top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
        top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
        return apply_top_k_top_p(logits, top_k, top_p)

    def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:

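
For reference (not part of this commit): the common thread in these sampler changes is that kernels are launched over logit rows (one per sampled token, which with speculative decoding can be more than one per request) instead of over requests, so per-request state is always looked up through expanded_idx_mapping. How that mapping is built is not shown in this diff; the sketch below is only an assumption of one plausible construction from cu_num_logits_np:

import numpy as np
import torch

idx_mapping = torch.tensor([5, 2, 7])      # request-state index of each request in the batch
cu_num_logits_np = np.array([0, 1, 4, 6])  # request 0 owns 1 logit row, request 1 owns 3, request 2 owns 2

num_logits = np.diff(cu_num_logits_np)     # [1, 3, 2]
expanded_idx_mapping = torch.repeat_interleave(idx_mapping, torch.from_numpy(num_logits))
print(expanded_idx_mapping)                # tensor([5, 2, 2, 2, 7, 7])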
@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup


@@ -54,11 +55,32 @@ class EagleCudaGraphManager:
    def get_cudagraph_size(self, num_tokens: int) -> int | None:
        return self.cudagraph_sizes.get(num_tokens)

    def get_cudagraph_runtime_mode(
        self, num_tokens: int
    ) -> tuple[CUDAGraphMode, int | None]:
        cudagraph_size = self.get_cudagraph_size(num_tokens)
        if cudagraph_size is None:
            cudagraph_mode = CUDAGraphMode.NONE
        else:
            cudagraph_mode = self.cudagraph_mode

        if (
            cudagraph_mode == CUDAGraphMode.FULL
            and cudagraph_size is not None
            and cudagraph_size not in self.graphs
        ):
            # If graph wasn't captured yet, fall back to eager.
            # This might happen when the dummy run is called before capture.
            cudagraph_mode = CUDAGraphMode.NONE
            cudagraph_size = None
        return cudagraph_mode, cudagraph_size

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        generate_fn: Callable,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
@@ -76,12 +98,11 @@ class EagleCudaGraphManager:
        attn_metadata, slot_mappings = prepare_inputs_to_capture(
            num_reqs,
            num_tokens,
            model_state,
            input_buffers,
            block_tables,
            attn_groups,
            self.max_model_len,
            kv_cache_config,
            uniform_decode_query_len=1,
        )
        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

@@ -158,6 +179,7 @@ class EagleCudaGraphManager:
    def capture(
        self,
        generate_fn: Callable,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
@@ -173,6 +195,7 @@ class EagleCudaGraphManager:
            capture_cudagraph_mode=self.cudagraph_mode,
            desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
            generate_fn=generate_fn,
            model_state=model_state,
            input_buffers=input_buffers,
            block_tables=block_tables,
            attn_groups=attn_groups,

@@ -16,7 +16,9 @@ from vllm.v1.worker.gpu.attn_utils import (
    build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
@@ -44,10 +46,13 @@ class EagleSpeculator:
        # the draft model's hidden size can be different from the target model's
        # hidden size (e.g., Llama 3.3 70B).
        self.hidden_size = self.draft_model_config.get_hidden_size()
        self.inputs_embeds_size = self.draft_model_config.get_inputs_embeds_size()
        self.vocab_size = self.draft_model_config.get_vocab_size()
        self.dtype = vllm_config.model_config.dtype

        # DP configuration
        self.dp_size = vllm_config.parallel_config.data_parallel_size
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank

        self.input_buffers = InputBuffers(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
@@ -77,10 +82,12 @@ class EagleSpeculator:

    def set_attn(
        self,
        model_state: ModelState,
        kv_cache_config: KVCacheConfig,
        attn_groups: list[list[AttentionGroup]],
        block_tables: BlockTables,
    ) -> None:
        self.model_state = model_state
        self.kv_cache_config = kv_cache_config
        self.attn_groups = attn_groups
        self.block_tables = block_tables
@@ -120,8 +127,8 @@ class EagleSpeculator:
        self,
        num_reqs: int,
        num_tokens_padded: int,
        attn_metadata: dict[str, Any],
        slot_mappings: dict[str, torch.Tensor],
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        num_tokens_across_dp: torch.Tensor | None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    ) -> None:
@@ -162,9 +169,10 @@ class EagleSpeculator:
            self.hidden_states,
            self.max_model_len,
        )
        self.block_tables.compute_slot_mappings(
            idx_mapping, query_start_loc, pos
        )
        if attn_metadata is not None:
            self.block_tables.compute_slot_mappings(
                idx_mapping, query_start_loc, pos
            )

    def capture_model(self) -> None:
        if self.num_speculative_steps == 1:
@@ -172,6 +180,7 @@ class EagleSpeculator:
        logger.info("Capturing model for Eagle speculator...")
        self.cudagraph_manager.capture(
            self.generate_draft,
            self.model_state,
            self.input_buffers,
            self.block_tables,
            self.attn_groups,
@@ -182,6 +191,8 @@ class EagleSpeculator:
    def propose(
        self,
        input_batch: InputBatch,
        attn_metadata: dict[str, Any],
        slot_mappings: dict[str, torch.Tensor],
        # [num_tokens, hidden_size]
        last_hidden_states: torch.Tensor,
        # num_layers x [num_tokens, hidden_size]
@@ -198,6 +209,9 @@ class EagleSpeculator:
        temperature: torch.Tensor,
        # [max_num_reqs]
        seeds: torch.Tensor,
        num_tokens_across_dp: torch.Tensor | None = None,
        dummy_run: bool = False,
        skip_attn_for_dummy_run: bool = False,
    ) -> torch.Tensor:
        # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
        # number of rejected tokens, we maintain the size of eagle's input_ids and
@@ -229,9 +243,9 @@ class EagleSpeculator:
            # TODO(woosuk): Support CUDA graph for prefill.
            last_hidden_states, hidden_states = self.run_model(
                num_tokens,
                input_batch.attn_metadata,
                input_batch.slot_mappings,
                num_tokens_across_dp=None,  # FIXME
                attn_metadata,
                slot_mappings,
                num_tokens_across_dp=num_tokens_across_dp,
            )
            sample_hidden_states = last_hidden_states[last_token_indices]
            logits = self.model.compute_logits(sample_hidden_states)
@@ -277,48 +291,64 @@ class EagleSpeculator:
            self.max_model_len,
            self.max_num_reqs,
        )
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
        slot_mappings = self.block_tables.compute_slot_mappings(
            idx_mapping, query_start_loc, pos
        )

        cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
        cudagraph_mode = self.cudagraph_manager.cudagraph_mode
        if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
        if not (dummy_run and skip_attn_for_dummy_run):
            query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
            slot_mappings = self.block_tables.compute_slot_mappings(
                idx_mapping, query_start_loc, pos
            )

        cudagraph_mode, cudagraph_size = (
            self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
        )
        num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
            get_cudagraph_and_dp_padding(
                num_reqs,
                cudagraph_size,
                cudagraph_mode.value,
                self.dp_size,
                self.dp_rank,
            )
        )
        cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
        if cudagraph_mode == CUDAGraphMode.FULL:
            # Run full CUDA graph.
            self.cudagraph_manager.run_fullgraph(cudagraph_size)
            self.cudagraph_manager.run_fullgraph(num_tokens_padded)
            return self.draft_tokens[:num_reqs]

        # Run eager or piecewise CUDA graph.
        num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
        query_start_loc_cpu = torch.arange(
            num_reqs + 1, dtype=torch.int32, device="cpu"
        )
        block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
        attn_metadata_updated = None
        slot_mappings_updated = None
        if not (dummy_run and skip_attn_for_dummy_run):
            query_start_loc_cpu = torch.arange(
                num_reqs + 1, dtype=torch.int32, device="cpu"
            )
            block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]

            # FIXME(woosuk): This is UNSAFE!!
            attn_metadata_updated = build_attn_metadata(
                attn_groups=self.attn_groups,
                num_reqs=num_reqs,
                num_tokens=num_reqs,
                query_start_loc_gpu=query_start_loc,
                query_start_loc_cpu=query_start_loc_cpu,
                max_query_len=1,
                seq_lens=self.input_buffers.seq_lens[:num_reqs],
                max_seq_len=self.max_model_len,
                block_tables=block_tables,
                slot_mappings=slot_mappings,
                kv_cache_config=self.kv_cache_config,
            )
            slot_mappings_updated = build_slot_mappings_by_layer(
                slot_mappings, self.kv_cache_config
            )

        # FIXME(woosuk): This is UNSAFE!!
        attn_metadata = build_attn_metadata(
            attn_groups=self.attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_reqs,
            query_start_loc_gpu=query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=1,
            seq_lens=self.input_buffers.seq_lens[:num_reqs],
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
        )
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
        self.generate_draft(
            num_reqs,
            num_tokens_padded,
            attn_metadata,
            slot_mappings_by_layer,
            num_tokens_across_dp=None,  # FIXME
            attn_metadata_updated,
            slot_mappings_updated,
            num_tokens_across_dp=num_tokens_across_dp,
            cudagraph_runtime_mode=cudagraph_mode,
        )
        return self.draft_tokens[:num_reqs]


vllm/v1/worker/gpu/warmup.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm import PoolingParams, SamplingParams
from vllm.v1.core.sched.output import (
    CachedRequestData,
    GrammarOutput,
    NewRequestData,
    SchedulerOutput,
)
from vllm.v1.request import Request
from vllm.v1.worker.gpu.model_runner import GPUModelRunner


@torch.inference_mode()
def warmup_kernels(model_runner: GPUModelRunner) -> None:
    """Run two execute_model + sample_tokens iterations to JIT compile
    triton kernels.

    The first iteration simulates a prefill with requests of 2 prompt
    tokens each. The second iteration simulates a decode step with all
    requests generating 1 token each.
    """
    prompt_token_ids = [0, 1]
    prompt_len = len(prompt_token_ids)
    num_reqs = min(
        model_runner.scheduler_config.max_num_seqs,
        model_runner.scheduler_config.max_num_batched_tokens // prompt_len,
    )

    num_kv_cache_groups = len(model_runner.kv_cache_config.kv_cache_groups)
    req_ids = [f"_warmup_{i}_" for i in range(num_reqs)]

    # SamplingParams exercising all sampling features.
    if model_runner.is_pooling_model:
        sampling_params = None
        pooling_params = PoolingParams()
    else:
        sampling_params = SamplingParams.for_sampler_warmup()
        pooling_params = None

    # Step 1: Prefill all requests with 2 prompt tokens each.
    new_reqs = [
        NewRequestData.from_request(
            Request(req_ids[i], prompt_token_ids, sampling_params, pooling_params),
            # Each request uses a distinct block per KV cache group.
            block_ids=tuple([i] for _ in range(num_kv_cache_groups)),
            prefill_token_ids=prompt_token_ids,
        )
        for i in range(num_reqs)
    ]

    prefill_output = SchedulerOutput.make_empty()
    prefill_output.scheduled_new_reqs = new_reqs
    prefill_output.num_scheduled_tokens = {rid: prompt_len for rid in req_ids}
    prefill_output.total_num_scheduled_tokens = prompt_len * num_reqs
    prefill_output.num_common_prefix_blocks = [0] * num_kv_cache_groups

    # Disable KV connector for warmup run.
    model_runner.kv_connector.set_disabled(True)
    model_runner.execute_model(prefill_output)

    if not model_runner.is_pooling_model:
        # Warm up sampler and perform a decode step for non-pooling models.

        grammar_output = None
        if model_runner.is_last_pp_rank:
            # Build a GrammarOutput to exercise the structured output bitmask
            # kernel during the prefill step.
            vocab_size = model_runner.model_config.get_vocab_size()
            bitmask_width = (vocab_size + 31) // 32
            grammar_bitmask = np.full(
                (len(req_ids), bitmask_width), fill_value=-1, dtype=np.int32
            )
            grammar_output = GrammarOutput(
                structured_output_request_ids=req_ids, grammar_bitmask=grammar_bitmask
            )

        model_runner.sample_tokens(grammar_output)

        # Step 2: Decode all requests with 1 token each.
        cached_req_data = CachedRequestData.make_empty()
        cached_req_data.req_ids = list(req_ids)
        cached_req_data.new_block_ids = [None] * num_reqs
        cached_req_data.num_computed_tokens = [prompt_len] * num_reqs
        cached_req_data.num_output_tokens = [1] * num_reqs

        decode_output = SchedulerOutput.make_empty()
        decode_output.scheduled_cached_reqs = cached_req_data
        decode_output.num_scheduled_tokens = {rid: 1 for rid in req_ids}
        decode_output.total_num_scheduled_tokens = num_reqs
        decode_output.num_common_prefix_blocks = [0] * num_kv_cache_groups

        model_runner.execute_model(decode_output)
        model_runner.sample_tokens(None)

    # Clean up: process finished_req_ids.
    cleanup_output = SchedulerOutput.make_empty()
    cleanup_output.finished_req_ids = set(req_ids)
    model_runner.execute_model(cleanup_output)
    model_runner.kv_connector.set_disabled(False)
    torch.cuda.synchronize()
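
For reference (not part of this commit): the grammar bitmask built in warmup_kernels packs one bit per vocabulary token into int32 words, so its width is ceil(vocab_size / 32) = (vocab_size + 31) // 32, and fill_value=-1 (all bits set) leaves every token allowed during the warmup step. A tiny worked example:

vocab_size = 32000
bitmask_width = (vocab_size + 31) // 32  # 1000 int32 words cover all 32000 token bits
assert bitmask_width * 32 >= vocab_size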