Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
@@ -15,13 +14,12 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
@@ -29,13 +27,11 @@ class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
uses_mrope: bool,
|
||||
use_aux_hidden_state_outputs: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.uses_mrope = uses_mrope
|
||||
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
|
||||
self.device = device
|
||||
|
||||
@@ -88,9 +84,8 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
model: nn.Module,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -113,24 +108,23 @@ class CudaGraphManager:
|
||||
)
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
if self.uses_mrope:
|
||||
assert mrope_positions is not None
|
||||
positions = mrope_positions[:, :num_tokens]
|
||||
if inputs_embeds is not None:
|
||||
inputs_embeds = inputs_embeds[:num_tokens]
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": input_buffers.input_ids[:num_tokens],
|
||||
"positions": input_buffers.positions[:num_tokens],
|
||||
# NOTE: Values returned by `prepare_dummy_inputs` will override the
|
||||
# default values above.
|
||||
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
|
||||
}
|
||||
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
model_state,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=(
|
||||
self.uniform_decode_query_len if uniform_decode else 0
|
||||
),
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
@@ -143,11 +137,7 @@ class CudaGraphManager:
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model_output = model(**model_inputs)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
@@ -164,9 +154,7 @@ class CudaGraphManager:
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
model_inputs=model_inputs,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
@@ -178,9 +166,7 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
model_inputs: dict[str, torch.Tensor | None],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
@@ -206,11 +192,8 @@ class CudaGraphManager:
|
||||
),
|
||||
torch.cuda.graph(graph, self.pool),
|
||||
):
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model_output = model(**model_inputs)
|
||||
|
||||
# Join offloader's copy stream after forward to avoid unjoined
|
||||
# stream error. The last layer's start_prefetch forks copy_stream,
|
||||
# but wait_prefetch only happens in the next forward pass.
|
||||
@@ -235,9 +218,7 @@ class CudaGraphManager:
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
model_inputs: dict[str, torch.Tensor | None],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
@@ -256,19 +237,14 @@ class CudaGraphManager:
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
model(**model_inputs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
@@ -278,9 +254,8 @@ class CudaGraphManager:
|
||||
device=self.device,
|
||||
capture_fn=self.capture_graph,
|
||||
model=model,
|
||||
model_state=model_state,
|
||||
input_buffers=input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
block_tables=block_tables,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
@@ -412,51 +387,36 @@ def capture_graphs(
|
||||
def prepare_inputs_to_capture(
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
max_model_len: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
uniform_decode_query_len: int = 0,
|
||||
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
|
||||
if uniform_decode_query_len > 0:
|
||||
num_tokens_per_req = uniform_decode_query_len
|
||||
else:
|
||||
num_tokens_per_req = num_tokens // num_reqs
|
||||
|
||||
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
|
||||
query_start_loc_np[-1] = num_tokens
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
|
||||
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
|
||||
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
|
||||
|
||||
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
|
||||
# rather than max_model_len.
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
|
||||
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
|
||||
|
||||
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
|
||||
slot_mappings = block_tables.slot_mappings[:, :num_tokens]
|
||||
input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
|
||||
input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
|
||||
slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
|
||||
slot_mappings_by_layer = build_slot_mappings_by_layer(
|
||||
slot_mappings, kv_cache_config
|
||||
)
|
||||
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=num_tokens_per_req,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
max_seq_len=max_model_len,
|
||||
block_tables=input_block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=kv_cache_config,
|
||||
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens,
|
||||
# HACK(woosuk): Special handling for DCP.
|
||||
if block_tables.cp_size > 1:
|
||||
prepare_dcp_local_seq_lens(
|
||||
input_buffers.dcp_local_seq_lens,
|
||||
input_batch.seq_lens,
|
||||
num_reqs,
|
||||
block_tables.cp_size,
|
||||
block_tables.cp_rank,
|
||||
block_tables.cp_interleave,
|
||||
)
|
||||
input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
|
||||
attn_metadata = model_state.prepare_attn(
|
||||
input_batch,
|
||||
input_block_tables,
|
||||
slot_mappings,
|
||||
attn_groups,
|
||||
kv_cache_config,
|
||||
)
|
||||
return attn_metadata, slot_mappings_by_layer
|
||||
|
||||
Reference in New Issue
Block a user