adapt to main2main for model runner v2 (#7578)
### What this PR does / why we need it?
This PR adapts model runner v2 to the newest commit of the vLLM main
branch. Please refer to
https://github.com/vllm-project/vllm-ascend/issues/5208
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.18.0
- vLLM main: ed359c497a
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
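For context, the upstream interface change this patch tracks is that `run_fullgraph` now takes a `BatchExecutionDescriptor` instead of a bare token count. The sketch below is a minimal stand-in, not the real upstream class; only the two fields this diff actually reads (`num_tokens` and `cg_mode`) are shown, and the enum is a local substitute for `vllm.config.compilation.CUDAGraphMode`:

```python
from dataclasses import dataclass
from enum import Enum


class CUDAGraphMode(Enum):  # local stand-in for vllm.config.compilation.CUDAGraphMode
    NONE = 0
    PIECEWISE = 1
    FULL = 2


@dataclass
class BatchExecutionDescriptor:
    """Assumed shape: only the fields this diff reads."""
    num_tokens: int
    cg_mode: CUDAGraphMode


# Old call shape: manager.run_fullgraph(num_tokens)
# New call shape: the runner passes a descriptor carrying the graph mode too.
desc = BatchExecutionDescriptor(num_tokens=512, cg_mode=CUDAGraphMode.FULL)
# manager.run_fullgraph(desc)  # reads desc.num_tokens and desc.cg_mode
```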
@@ -16,128 +16,68 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from contextlib import contextmanager
 from typing import Any

-import numpy as np
 import torch
 import torch.nn as nn
-import vllm
 from vllm.config import VllmConfig
 from vllm.config.compilation import CUDAGraphMode
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import logger
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
 from vllm.v1.worker.gpu.block_table import BlockTables
-from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
+from vllm.v1.worker.gpu.cudagraph_utils import BatchExecutionDescriptor, ModelCudaGraphManager
 from vllm.v1.worker.gpu.input_batch import InputBuffers
 from vllm.v1.worker.gpu.model_states.interface import ModelState
 from vllm.v1.worker.utils import AttentionGroup

 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.compilation.acl_graph import set_graph_params, update_full_graph_params
-from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
-from vllm_ascend.worker.v2.utils import torch_cuda_wrapper


-class AclGraphManager(CudaGraphManager):
-    """ACL Graph Manager for Ascend NPUs."""
+class ModelAclGraphManager(ModelCudaGraphManager):
+    """ACL Model CUDA Graph Manager for Ascend NPUs."""

     def __init__(
         self,
         vllm_config: VllmConfig,
+        use_aux_hidden_state_outputs: bool,
         device: torch.device,
-        cudagraph_mode: CUDAGraphMode,
-        decode_query_len: int,
-        model_runner: Any,
+        model_runner: Any,  # NPUModelRunner; typed as Any to avoid a circular import
     ):
-        super().__init__(
-            vllm_config,
-            device,
-            cudagraph_mode,
-            decode_query_len,
-        )
+        super().__init__(
+            vllm_config,
+            use_aux_hidden_state_outputs,
+            device,
+        )
         # Set the model runner attribute so we can access model runner
         # attributes when `run_fullgraph` is called in the base graph manager;
         # this way we don't need to copy the `execute_model` method into the
         # `NPUModelRunner` class.
         self.model_runner = model_runner
+        # capture_sizes is sorted in ascending order.
+        self.capture_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
         # vllm-ascend needs to update the graph params of the attention backend,
         # so we set the graph params before capturing the full graph.
         if super().needs_capture():
-            set_graph_params(self.cudagraph_sizes)
+            set_graph_params(self.capture_sizes)
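For intuition, `set_graph_params` has to run before any capture so that per-size graph state exists for the attention backend to fill in during capture and patch during replay (which is what `update_full_graph_params` is for). A minimal sketch of this size-keyed registry pattern, with hypothetical names rather than the actual `vllm_ascend.compilation.acl_graph` internals:

```python
from typing import Any

# Hypothetical size-keyed registry; the real one lives in
# vllm_ascend.compilation.acl_graph and stores attention graph params.
_GRAPH_PARAMS: dict[int, dict[str, Any]] = {}


def set_graph_params(capture_sizes: list[int]) -> None:
    """Create one (initially empty) slot per batch size to be captured."""
    for size in capture_sizes:
        _GRAPH_PARAMS.setdefault(size, {})


def update_graph_params(size: int, **params: Any) -> None:
    """Called by the backend during capture/replay; the slot must already exist."""
    _GRAPH_PARAMS[size].update(params)


set_graph_params(sorted([8, 1, 4, 2]))  # register before any capture
update_graph_params(4, seq_len=4)       # safe: slot 4 was pre-registered
```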

-    def _capture_full_graph(
-        self,
-        num_tokens: int,
-        num_reqs: int,
-        model: nn.Module,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        inputs_embeds: torch.Tensor | None,
-        num_tokens_across_dp: torch.Tensor,
-        attn_metadata: dict[str, Any] | None,
-        slot_mappings: dict[str, torch.Tensor] | None,
-        has_lora: bool = False,
-    ) -> None:
-        """Override _capture_full_graph because we need to set capturing=True in the forward context."""
-        # Set capturing=True before the model forward.
-        model = ModelWithContext(model)
-        return super()._capture_full_graph(
-            num_tokens,
-            num_reqs,
-            model,
-            input_ids,
-            positions,
-            inputs_embeds,
-            num_tokens_across_dp,
-            attn_metadata,
-            slot_mappings,
-            has_lora,
-        )
-
-    def capture_graph(
-        self,
-        num_tokens: int,
-        capture_cg_mode: CUDAGraphMode,
-        model: nn.Module,
-        model_state: ModelState,
-        input_buffers: InputBuffers,
-        block_tables: BlockTables,
-        attn_groups: list[list[AttentionGroup]],
-        kv_cache_config: KVCacheConfig,
-        has_lora: bool = False,
-        uniform_decode: bool = False,
-    ) -> None:
-        with torch_cuda_wrapper(), prepare_capture_inputs_wrapper():
-            super().capture_graph(
-                num_tokens,
-                capture_cg_mode,
-                model,
-                model_state,
-                input_buffers,
-                block_tables,
-                attn_groups,
-                kv_cache_config,
-                has_lora,
-                uniform_decode,
-            )
-
-    def run_fullgraph(self, num_tokens: int) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+    def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
         """Override run_fullgraph to update the full graph params in run_fullgraph."""
+        num_tokens = desc.num_tokens
         logger.info_once(f"run_fullgraph with num_tokens={num_tokens}")
-        ret = super().run_fullgraph(num_tokens)
-        assert self.model_runner.cudagraph_and_dp_padding is not None
+        ret = super().run_fullgraph(desc)

         positions = self.model_runner.input_buffers.positions[:num_tokens]
-        _num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
-            self.model_runner.cudagraph_and_dp_padding
-        )
-        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
+        # Refer to vllm.v1.worker.gpu.dp_utils.sync_cudagraph_and_dp_padding for
+        # how num_tokens_across_dp is calculated.
+        num_tokens_across_dp = torch.full([self.model_runner.dp_size], num_tokens, device=self.device)
         with set_forward_context(
             self.model_runner.input_batch.attn_metadata,
             self.vllm_config,
             num_tokens=num_tokens,
-            cudagraph_runtime_mode=cudagraph_runtime_mode,
+            cudagraph_runtime_mode=desc.cg_mode,
             num_tokens_across_dp=num_tokens_across_dp,
             batch_descriptor=None,  # Full graph mode does not need a batch_descriptor.
             slot_mapping=self.model_runner.input_batch.slot_mappings,
@@ -155,79 +95,31 @@ class AclGraphManager(CudaGraphManager):
         )
         return ret

-    def is_uniform_decode(
-        self,
-        num_reqs: int,
-        num_tokens: int,
-        max_query_len: int,
-    ):
-        return (max_query_len == self.uniform_decode_query_len) and (num_tokens == max_query_len * num_reqs)
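To make the (now removed) predicate concrete, here is a worked example with `uniform_decode_query_len = 1`, i.e. plain decode where every request contributes exactly one token:

```python
uniform_decode_query_len = 1  # plain decode: one new token per request


def is_uniform_decode(num_reqs: int, num_tokens: int, max_query_len: int) -> bool:
    return (max_query_len == uniform_decode_query_len) and (num_tokens == max_query_len * num_reqs)


assert is_uniform_decode(num_reqs=8, num_tokens=8, max_query_len=1)        # pure decode batch
assert not is_uniform_decode(num_reqs=3, num_tokens=20, max_query_len=16)  # mixed prefill batch
```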

-
-@contextmanager
-def prepare_capture_inputs_wrapper():
-    """Context manager to override input preparation for NPU graph capture."""
-    # TODO(Ronald1995): make prepare_inputs_to_capture a static method
-    # in CudaGraphManager.
-    ori = vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture
-    try:
-        vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = prepare_inputs_to_capture
-        yield
-    finally:
-        vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = ori
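The wrapper above temporarily swaps a module-level function and always restores it, even if capture raises. The same pattern in isolation (self-contained sketch; `target` is a hypothetical stand-in for `vllm.v1.worker.gpu.cudagraph_utils`):

```python
import types
from contextlib import contextmanager

# Hypothetical stand-in for the module being patched.
target = types.SimpleNamespace(prepare=lambda: "original")


@contextmanager
def patched_prepare(replacement):
    original = target.prepare
    try:
        target.prepare = replacement
        yield
    finally:
        # Restored even if the body raises.
        target.prepare = original


with patched_prepare(lambda: "npu-specific"):
    assert target.prepare() == "npu-specific"
assert target.prepare() == "original"
```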

-
-def prepare_inputs_to_capture(
-    num_reqs: int,
-    num_tokens: int,
-    input_buffers: InputBuffers,
-    block_tables: BlockTables,
-    attn_groups: list[list[AttentionGroup]],
-    max_model_len: int,
-    kv_cache_config: KVCacheConfig,
-    uniform_decode_query_len: int = 0,
-) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
-    if uniform_decode_query_len > 0:
-        num_tokens_per_req = uniform_decode_query_len
-    else:
-        num_tokens_per_req = num_tokens // num_reqs
-
-    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
-    query_start_loc_np[-1] = num_tokens
-    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
-    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
-    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
-    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
-
-    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
-    # rather than max_model_len.
-    input_buffers.seq_lens[:num_reqs] = num_tokens
-    input_buffers.seq_lens[num_reqs:] = 0
-    input_buffers.seq_lens_cpu[:num_reqs] = num_tokens
-    input_buffers.seq_lens_cpu[num_reqs:] = 0
-
-    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
-    input_buffers.dcp_local_seq_lens[num_reqs:] = 0
-
-    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
-    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
-    slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, kv_cache_config)
-
-    attn_metadata = build_attn_metadata(
-        attn_groups=attn_groups,
-        num_reqs=num_reqs,
-        num_tokens=num_tokens,
-        query_start_loc_gpu=query_start_loc,
-        query_start_loc_cpu=query_start_loc_cpu,
-        max_query_len=num_tokens_per_req,
-        seq_lens=input_buffers.seq_lens,
-        max_seq_len=max_model_len,
-        block_tables=input_block_tables,
-        slot_mappings=slot_mappings,
-        kv_cache_config=kv_cache_config,
-        seq_lens_np=input_buffers.seq_lens_np,
-    )
-    return attn_metadata, slot_mappings_by_layer
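To make the `query_start_loc` construction above concrete: for `num_reqs=3` and `num_tokens=12` (so `num_tokens_per_req=4`), the cumulative offsets come out as follows (numpy-only sketch):

```python
import numpy as np

num_reqs, num_tokens = 3, 12
num_tokens_per_req = num_tokens // num_reqs  # 4

query_start_loc = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc[-1] = num_tokens  # clamp the last boundary to the token count
print(query_start_loc)  # [ 0  4  8 12] -> request i owns tokens [loc[i], loc[i+1])

# With a remainder, the last request absorbs the extra tokens:
num_reqs, num_tokens = 3, 14
num_tokens_per_req = num_tokens // num_reqs  # 4
query_start_loc = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc[-1] = num_tokens
print(query_start_loc)  # [ 0  4  8 14]
```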

+    def capture(
+        self,
+        model: nn.Module,
+        model_state: ModelState,
+        input_buffers: InputBuffers,
+        block_tables: BlockTables,
+        attn_groups: list[list[AttentionGroup]],
+        kv_cache_config: KVCacheConfig,
+        has_lora: bool = False,
+        use_aux_hidden_state_outputs: bool = False,
+        progress_bar_desc: str = "Capturing CUDA graphs",
+    ) -> None:
+        """Capture CUDA graphs for the model forward pass."""
+        model = ModelWithContext(model)
+        return super().capture(
+            model,
+            model_state,
+            input_buffers,
+            block_tables,
+            attn_groups,
+            kv_cache_config,
+            has_lora,
+            use_aux_hidden_state_outputs,
+            progress_bar_desc,
+        )

@@ -242,6 +134,7 @@ class ModelWithContext(nn.Module):
     def forward(self, *args, **kwargs):
         # In the warmup phase, capturing=False by default.
         # When capturing, we need to set capturing=True in the forward context.
-        _EXTRA_CTX.capturing = True
+        if torch.npu.is_current_stream_capturing():
+            _EXTRA_CTX.capturing = True

         return self.original_model(*args, **kwargs)
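The `ModelWithContext` wrapper boils down to a transparent `nn.Module` that flips a context flag before delegating. A minimal self-contained sketch of the same pattern (the `_EXTRA_CTX` object and the NPU stream check are stand-ins here, so the sketch runs on any backend):

```python
import types

import torch
import torch.nn as nn

# Stand-in for vllm_ascend.ascend_forward_context._EXTRA_CTX.
_EXTRA_CTX = types.SimpleNamespace(capturing=False)


class ModelWithContext(nn.Module):
    """Wraps a model so graph capture is visible to the forward context."""

    def __init__(self, original_model: nn.Module):
        super().__init__()
        self.original_model = original_model

    def forward(self, *args, **kwargs):
        # On NPU builds this is torch.npu.is_current_stream_capturing();
        # the getattr guard lets the sketch run where torch.npu is absent.
        if getattr(torch, "npu", None) and torch.npu.is_current_stream_capturing():
            _EXTRA_CTX.capturing = True
        return self.original_model(*args, **kwargs)


wrapped = ModelWithContext(nn.Linear(4, 4))
out = wrapped(torch.randn(2, 4))  # delegates transparently to the wrapped model
```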