[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2; see RFC
#5208. The PR contains these modifications:
- adapt to the latest commit of the vllm main branch.
- provide a unified interface for the extra forward context shared by model
runner v1 and model runner v2 (see the sketch below).
- implement graph mode for the main model.
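
As a rough illustration of the unified interface, a minimal sketch of the shared extra forward context is shown below. Only the `capturing` flag is grounded in this diff (the code imports `_EXTRA_CTX` from `vllm_ascend.ascend_forward_context` and sets `_EXTRA_CTX.capturing = True` during capture); the class name, structure, and defaults are hypothetical.

```python
# Hypothetical sketch of the shared extra forward context; only the
# `capturing` flag is grounded in this PR (see ModelWithContext.forward).
from dataclasses import dataclass


@dataclass
class ExtraForwardContext:
    # True only while an ACL graph is being captured; False during normal
    # eager/warmup runs and during graph replay.
    capturing: bool = False


# Module-level singleton, shared by model runner v1 and model runner v2.
_EXTRA_CTX = ExtraForwardContext()
```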
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
@@ -5,4 +5,5 @@ This directory contains the new model runner which is under active development.
please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208)
to get specific plans.
-supported vllm version: main@1339784
+supported vllm version: main@4034c3d32e30d01639459edd3ab486f56993876d
+related PR: <https://github.com/vllm-project/vllm-ascend/pull/7110>
@@ -19,16 +19,25 @@
from contextlib import contextmanager
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import vllm
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import logger
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.cudagraph_utils import prepare_inputs_to_capture as prepare_inputs_to_capture_gpu
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup

from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.compilation.acl_graph import set_graph_params, update_full_graph_params
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper

@@ -38,44 +47,134 @@ class AclGraphManager(CudaGraphManager):
    def __init__(
        self,
        vllm_config: VllmConfig,
-        use_mrope: bool,
+        use_aux_hidden_state_outputs: bool,
        device: torch.device,
+        model_runner: Any,  # NPUModelRunner; typed as Any to avoid a circular import
    ):
-        with torch_cuda_wrapper():
-            super().__init__(vllm_config, use_mrope, device)
        # Set the model runner attribute so that we can access the model runner's
        # attributes when `run_fullgraph` is called in CudaGraphManager; then we
        # don't need to copy the `execute_model` method into `NPUModelRunner`.
        self.model_runner = model_runner
        super().__init__(
            vllm_config,
            use_aux_hidden_state_outputs,
            device,
        )
        # vllm-ascend needs to update the graph params of the attention backend,
        # so we set the graph params before capturing the full graph.
        if super().needs_capture():
            set_graph_params(self.cudagraph_sizes)

    def _capture_full_graph(
        self,
        num_tokens: int,
        num_reqs: int,
        model: nn.Module,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None,
        num_tokens_across_dp: torch.Tensor,
        attn_metadata: dict[str, Any] | None,
        slot_mappings: dict[str, torch.Tensor] | None,
        has_lora: bool = False,
    ) -> None:
        """Override _capture_full_graph because we need to set capturing=True in the forward context."""
        # Set capturing=True before the model forward.
        model = ModelWithContext(model)
        return super()._capture_full_graph(
            num_tokens,
            num_reqs,
            model,
            input_ids,
            positions,
            inputs_embeds,
            num_tokens_across_dp,
            attn_metadata,
            slot_mappings,
            has_lora,
        )

    def capture_graph(
        self,
        num_tokens: int,
        capture_cg_mode: CUDAGraphMode,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        uniform_decode: bool = False,
    ) -> None:
        with torch_cuda_wrapper(), prepare_capture_inputs_wrapper():
            super().capture_graph(
                num_tokens,
                capture_cg_mode,
                model,
                model_state,
                input_buffers,
                block_tables,
                attn_metadata_builders,
                attn_groups,
                kv_cache_config,
                has_lora,
                uniform_decode,
            )

    def run_fullgraph(self, num_tokens: int) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Override run_fullgraph to also update the full-graph params."""
        logger.info_once(f"run_fullgraph with num_tokens={num_tokens}")
        ret = super().run_fullgraph(num_tokens)
        assert self.model_runner.cudagraph_and_dp_padding is not None

        positions = self.model_runner.input_buffers.positions[:num_tokens]
        _num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
            self.model_runner.cudagraph_and_dp_padding
        )
        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)

        with set_forward_context(
            self.model_runner.input_batch.attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
            cudagraph_runtime_mode=cudagraph_runtime_mode,
            num_tokens_across_dp=num_tokens_across_dp,
            batch_descriptor=None,  # Full-graph mode doesn't need a batch_descriptor.
            slot_mapping=self.model_runner.input_batch.slot_mappings,
        ):
            forward_context = get_forward_context()
            update_full_graph_params(
                # FIXME(Ronald1995): support hybrid attn backends.
                list(self.model_runner.attn_backends.values())[0],
                self.model_runner.update_stream,
                forward_context,
                num_tokens,
                self.vllm_config,
                self.model_runner.speculative_config,
                positions.shape[0],
            )
        return ret

    def is_uniform_decode(
        self,
        num_reqs: int,
        num_tokens: int,
        max_query_len: int,
    ):
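        # Illustrative example: with uniform_decode_query_len == 1, a batch of
        # num_reqs == 4 decode-only requests has max_query_len == 1 and
        # num_tokens == 4 == max_query_len * num_reqs, so it is uniform decode.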
        return (max_query_len == self.uniform_decode_query_len) and (num_tokens == max_query_len * num_reqs)

@contextmanager
def prepare_capture_inputs_wrapper():
    """Context manager to override input preparation for NPU graph capture."""
    # TODO(Ronald1995): make prepare_inputs_to_capture a static method
    # of CudaGraphManager.
    global prepare_inputs_to_capture_gpu
    ori = vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture
    try:
        ori_func = prepare_inputs_to_capture_gpu
        prepare_inputs_to_capture_gpu = prepare_inputs_to_capture
        vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = prepare_inputs_to_capture
        yield
    finally:
        prepare_inputs_to_capture_gpu = ori_func
        vllm.v1.worker.gpu.cudagraph_utils.prepare_inputs_to_capture = ori


def prepare_inputs_to_capture(
@@ -83,9 +182,66 @@ def prepare_inputs_to_capture(
    num_tokens: int,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    attn_groups: list[list[AttentionGroup]],
    max_model_len: int,
    kv_cache_config: KVCacheConfig,
-) -> dict[str, Any]:
-    # TODO(Ronald1995): Implement NPU specific input preparation.
-    return {}
+    uniform_decode_query_len: int = 0,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    if uniform_decode_query_len > 0:
        num_tokens_per_req = uniform_decode_query_len
    else:
        num_tokens_per_req = num_tokens // num_reqs

    query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
    query_start_loc_np[-1] = num_tokens
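    # Illustrative example: num_reqs=3, num_tokens_per_req=2, num_tokens=7
    # yields [0, 2, 4, 6] and then [0, 2, 4, 7]: the last request absorbs the
    # remainder so that query_start_loc ends exactly at num_tokens.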
    query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
    input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
    input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
    query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]

    # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
    # rather than max_model_len.
    input_buffers.seq_lens[:num_reqs] = num_tokens
    input_buffers.seq_lens[num_reqs:] = 0
    input_buffers.seq_lens_cpu[:num_reqs] = num_tokens
    input_buffers.seq_lens_cpu[num_reqs:] = 0

    input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
    input_buffers.dcp_local_seq_lens[num_reqs:] = 0

    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
    slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, kv_cache_config)

    attn_metadata = build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=num_reqs,
        num_tokens=num_tokens,
        query_start_loc_gpu=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        max_query_len=num_tokens_per_req,
        seq_lens=input_buffers.seq_lens,
        max_seq_len=max_model_len,
        block_tables=input_block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        seq_lens_np=input_buffers.seq_lens_np,
    )
    return attn_metadata, slot_mappings_by_layer


class ModelWithContext(nn.Module):
    """Define a wrapper model to inject the forward context,
    so we can inherit vllm's CudaGraphManager._capture_full_graph.
    """

    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model

    def forward(self, *args, **kwargs):
        # In the warmup phase, capturing is False by default;
        # when capturing, we need to set capturing=True in the forward context.
        _EXTRA_CTX.capturing = True

        return self.original_model(*args, **kwargs)

@@ -23,8 +23,8 @@ from typing import Any
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
from vllm.v1.worker.utils import AttentionGroup

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -43,7 +43,7 @@ def get_attn_mask_builder(device: torch.device):

def build_attn_metadata(
    *,
-    attn_metadata_builders: list[AttentionMetadataBuilder],
+    attn_groups: list[list[AttentionGroup]],
    num_reqs: int,
    num_tokens: int,
    query_start_loc_gpu: torch.Tensor,
@@ -54,6 +54,7 @@ def build_attn_metadata(
    block_tables: Sequence[torch.Tensor],
    slot_mappings: torch.Tensor,
    kv_cache_config: KVCacheConfig,
+    dcp_local_seq_lens: torch.Tensor | None = None,
    # extra attributes for ascend npus.
    seq_lens_np: np.ndarray | None = None,
    num_computed_tokens_cpu: torch.Tensor | None = None,
@@ -72,9 +73,6 @@ def build_attn_metadata(
    if seq_lens_np is None:
        seq_lens_np = np.full(num_reqs, max_seq_len, dtype=np.int32)
    seq_lens_cpu = torch.from_numpy(seq_lens_np)[:num_reqs]
-    # torch_npu._reshape_and_cache operator requires slot_mappings to
-    # be torch.int32.
-    slot_mappings = slot_mappings.to(torch.int32)

    attn_metadata: dict[str, Any] = {}
    kv_cache_groups = kv_cache_config.kv_cache_groups
@@ -100,13 +98,14 @@
            max_seq_len=max_seq_len,
        )

-        attn_metadata_builder = attn_metadata_builders[i]
-        metadata = attn_metadata_builder.build(
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,  # type: ignore
-        )
-        for layer_name in kv_cache_spec.layer_names:
-            attn_metadata[layer_name] = metadata
+        for attn_group in attn_groups[i]:
+            attn_metadata_builder = attn_group.get_metadata_builder(0)
+            metadata = attn_metadata_builder.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+            for layer_name in attn_group.layer_names:
+                attn_metadata[layer_name] = metadata
    return attn_metadata

vllm_ascend/worker/v2/block_table.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/block_table.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
from vllm.v1.worker.gpu.block_table import BlockTables


class AscendBlockTables(BlockTables):
    """Block tables for Ascend NPUs."""

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
        cp_size: int = 1,
        cp_rank: int = 0,
        cp_interleave: int = 1,
    ):
        super().__init__(
            block_sizes,
            max_num_reqs,
            max_num_batched_tokens,
            max_model_len,
            device,
            cp_size,
            cp_rank,
            cp_interleave,
        )
        # Because we will override this attribute, delete it first so the old
        # tensor is collected by the Python GC immediately.
        del self.slot_mappings
        # vllm-ascend's reshape_and_cache function requires slot_mappings to be
        # int32, so redefine slot_mappings with dtype torch.int32.
        self.slot_mappings: torch.Tensor = torch.zeros(
            self.num_kv_cache_groups,
            self.max_num_batched_tokens,
            dtype=torch.int32,
            device=self.device,
        )
@@ -22,6 +22,8 @@ import numpy as np
import torch
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers

from vllm_ascend.attention.attention_v1 import AscendAttentionState


class AscendInputBuffers(InputBuffers):
    """Input buffers for Ascend NPUs."""
@@ -37,6 +39,16 @@ class AscendInputBuffers(InputBuffers):
            max_num_tokens,
            device,
        )
        del self.query_start_loc

        # NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
        # See _pad_query_start_loc_for_fia.
        self.query_start_loc: torch.Tensor = torch.zeros(
            max_num_reqs + 2,
            dtype=torch.int32,
            device=device,
        )

        # Create seq_lens_cpu and seq_lens_np.
        # npu's attention backend still needs seq_lens on the CPU side.
        self.seq_lens_cpu: torch.Tensor = torch.zeros(
@@ -56,6 +68,8 @@ class AscendInputBatch(InputBatch):
    # Create seq_lens_np.
    # npu's attention backend still needs seq_lens on the CPU side.
    seq_lens_np: np.ndarray
    # attn_state is used to build attention metadata.
    attn_state: AscendAttentionState | None = None

    @classmethod
    def make_dummy(
@@ -79,4 +93,11 @@ class AscendInputBatch(InputBatch):
        input_buffers.seq_lens_np[num_reqs:] = 0
        seq_lens_np = input_buffers.seq_lens_np[:num_reqs]
        input_batch.seq_lens_np = seq_lens_np
        # A dummy run is used either for DP or for memory profiling.
        # For a DP dummy run, num_tokens is set to 1, so attn_state is set to
        # DecodeOnly. For a memory-profiling dummy run, attention metadata
        # isn't needed, so we can also set attn_state to
        # AscendAttentionState.DecodeOnly.
        input_batch.attn_state = AscendAttentionState.DecodeOnly
        return cls(**asdict(input_batch), seq_lens_np=seq_lens_np)
@@ -17,12 +17,16 @@
# This file is a part of the vllm-ascend project.
#

import functools

import numpy as np
import torch
import vllm
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.config.compilation import CUDAGraphMode
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.input_batch import (
    combine_sampled_and_draft_tokens,
@@ -32,23 +36,23 @@ from vllm.v1.worker.gpu.input_batch import (
)
from vllm.v1.worker.gpu.model_runner import GPUModelRunner

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import set_weight_prefetch_method
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
-from vllm_ascend.worker.v2.attn_utils import build_attn_metadata, build_attn_state
+from vllm_ascend.worker.v2.attn_utils import build_attn_state
from vllm_ascend.worker.v2.input_batch import AscendInputBatch, AscendInputBuffers
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
from vllm_ascend.worker.v2.spec_decode import init_speculator
from vllm_ascend.worker.v2.spec_decode.eagle import AscendEagleSpeculator
from vllm_ascend.worker.v2.states import AscendRequestState
-from vllm_ascend.worker.v2.utils import torch_cuda_wrapper

logger = init_logger(__name__)
+from vllm_ascend.worker.v2.utils import block_table_wrapper, model_states_wrapper, torch_cuda_wrapper


class NPUModelRunner(GPUModelRunner):
    """Model runner for Ascend NPUs."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
-        with torch_cuda_wrapper():
+        with torch_cuda_wrapper(), block_table_wrapper(), model_states_wrapper():
            super().__init__(vllm_config, device)

        # Because we will override these attributes, delete them to
@@ -62,8 +66,9 @@ class NPUModelRunner(GPUModelRunner):
        # NPU specific initializations can be added below.
        self.cudagraph_manager: AclGraphManager = AclGraphManager(
            self.vllm_config,
-            self.uses_mrope,
+            self.use_aux_hidden_state_outputs,
            self.device,
+            self,
        )

        # we define AscendEagleSpeculator in vllm_ascend.worker.v2.spec_decode.eagle
@@ -96,6 +101,7 @@ class NPUModelRunner(GPUModelRunner):
            max_num_reqs=self.max_num_reqs,
            vocab_size=self.vocab_size,
            device=self.device,
            req_states=self.req_states,
            logprobs_mode=self.model_config.logprobs_mode,
            num_speculative_tokens=self.num_speculative_steps + 1,
        )
@@ -113,6 +119,59 @@ class NPUModelRunner(GPUModelRunner):
            pin_memory=True,
        )

        # Ascend-specific configurations
        self.ascend_config = get_ascend_config()
        # Set this exactly as model runner v1 does, or it will raise an error.
        set_weight_prefetch_method(self.ascend_config.weight_prefetch_config)

        # We need to update the full-graph params in run_fullgraph, so create
        # a stream for updating them.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.update_stream: torch.npu.Stream = torch.npu.Stream()

        # We need the return value of `get_cudagraph_and_dp_padding` to set the
        # forward_context in `run_fullgraph`, so store it here; this lets us
        # inherit the `execute_model` method.
        self.cudagraph_and_dp_padding: tuple[int, torch.Tensor | None, int] | None = None

        # We also need input_batch to set the forward_context in run_fullgraph,
        # again so we can inherit the `execute_model` method.
        self.input_batch: AscendInputBatch | None = None

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
        dummy_run: bool = False,
        skip_attn_for_dummy_run: bool = False,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        """Override GPUModelRunner.execute_model for Ascend NPUs for the
        following reason:
        1. When running the full graph, we need the return value of
           `get_cudagraph_and_dp_padding` to set the forward_context in
           `run_fullgraph`.
        """

        # Use a closure to store the return value of
        # get_cudagraph_and_dp_padding on the model runner.
        def wrapper(func):
            @functools.wraps(func)
            def inner(*args, **kwargs):
                self.cudagraph_and_dp_padding = func(*args, **kwargs)
                return self.cudagraph_and_dp_padding

            return inner

        if self.cudagraph_and_dp_padding is None:
            vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding = wrapper(
                vllm.v1.worker.gpu.model_runner.get_cudagraph_and_dp_padding
            )

        return super().execute_model(
            scheduler_output,
            intermediate_tensors,
            dummy_run,
            skip_attn_for_dummy_run,
        )

    def prepare_inputs(
        self,
        scheduler_output: SchedulerOutput,
@@ -185,33 +244,40 @@ class NPUModelRunner(GPUModelRunner):
            idx_mapping, total_num_logits, cu_num_logits, max_expand_len
        )

        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

        # Get query_start_loc.
-        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
+        # NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
+        # See _pad_query_start_loc_for_fia.
+        query_start_loc_np = np.empty(self.max_num_reqs + 2, dtype=np.int32)
        query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
        # Pad for full CUDA graph mode.
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
        query_start_loc_np[num_reqs + 1 :] = num_tokens

+        # This is only required for vllm-ascend.
+        query_start_loc_np, num_reqs_padded = self._pad_query_start_loc_for_fia(
+            num_tokens_padded=num_tokens_after_padding,
+            num_tokens=num_tokens,
+            num_reqs=num_reqs,
+            query_start_loc_np=query_start_loc_np,
+            max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
+        )
        async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)

-        query_start_loc_np = query_start_loc_np[: num_reqs + 1]
-        query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
+        query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
        max_query_len = num_scheduled_tokens.max().item()

-        # Get prefill tokens.
-        prepare_prefill_inputs(
-            self.input_buffers.input_ids,
-            self.req_states.next_prefill_tokens,
-            idx_mapping,
-            query_start_loc,
-            self.req_states.prefill_token_ids.gpu,
-            self.req_states.prefill_len.gpu,
-            self.req_states.num_computed_tokens.gpu,
-        )
+        # Get prefill tokens if any.
+        if self.req_states.any_prefills(idx_mapping_np):
+            prepare_prefill_inputs(
+                self.input_buffers.input_ids,
+                self.req_states.next_prefill_tokens,
+                idx_mapping,
+                query_start_loc,
+                self.req_states.all_token_ids.gpu,
+                self.req_states.prefill_len.gpu,
+                self.req_states.num_computed_tokens.gpu,
+            )

        # Prepare positions and seq_lens.
        prepare_pos_seq_lens(
@@ -223,14 +289,8 @@ class NPUModelRunner(GPUModelRunner):
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

-        # Prepare M-RoPE positions.
-        if self.uses_mrope:
-            self.mrope_states.prepare_mrope_positions(
-                idx_mapping,
-                query_start_loc,
-                self.req_states.prefill_len.gpu,
-                self.req_states.num_computed_tokens.gpu,
-            )
+        # Pad for full CUDA graph mode.
+        self.input_buffers.seq_lens_np[num_reqs_padded:] = 0

        # Some input token ids are directly read from the last sampled tokens
        # and draft tokens. Also, get the logits indices to sample tokens from.
@@ -246,43 +306,12 @@ class NPUModelRunner(GPUModelRunner):
            total_num_logits,
        )

-        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
-        slot_mappings = self.block_tables.compute_slot_mappings(
-            idx_mapping, query_start_loc, self.input_buffers.positions[:num_tokens]
-        )
-        # Layer name -> slot mapping.
-        slot_mappings_by_layer = build_slot_mappings_by_layer(slot_mappings, self.kv_cache_config)
-        # Layer name -> attention metadata.
-        # TODO(Ronald1995): try to add a new method `build_attn_metadata` in
-        # vllm gpu_model_runner_v2; maybe then we wouldn't need to override the
-        # `prepare_inputs` method like this.
-        attn_metadata = build_attn_metadata(
-            attn_metadata_builders=self.attn_metadata_builders,
-            num_reqs=num_reqs,
-            num_tokens=num_tokens,
-            query_start_loc_gpu=query_start_loc,
-            query_start_loc_cpu=query_start_loc_cpu,
-            max_query_len=max_query_len,
-            seq_lens=self.input_buffers.seq_lens,
-            max_seq_len=self.max_model_len,
-            block_tables=block_tables,
-            slot_mappings=slot_mappings,
-            kv_cache_config=self.kv_cache_config,
-            # extra attributes for ascend npus.
-            seq_lens_np=self.input_buffers.seq_lens_np,
-            num_computed_tokens_cpu=self.req_states.num_computed_tokens_cpu[idx_mapping_cpu],
-            attn_state=attn_state,
-        )

        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
        positions = self.input_buffers.positions[:num_tokens_after_padding]
        mrope_positions = None
        if self.uses_mrope:
            mrope_positions = self.mrope_states.mrope_positions
            mrope_positions = mrope_positions[:, :num_tokens_after_padding]
-        return AscendInputBatch(
+        self.input_batch = AscendInputBatch(
            req_ids=req_ids,
-            num_reqs=num_reqs,
+            num_reqs=num_reqs_padded,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            expanded_idx_mapping=expanded_idx_mapping,
@@ -294,18 +323,18 @@ class NPUModelRunner(GPUModelRunner):
            query_start_loc=query_start_loc,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            dcp_local_seq_lens=None,  # TODO(Ronald1995): support cp.
            input_ids=input_ids,
            positions=positions,
            mrope_positions=mrope_positions,
            inputs_embeds=None,
            attn_metadata=attn_metadata,
            slot_mappings=slot_mappings_by_layer,
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
            cu_num_logits_np=cu_num_logits_np,
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
+            # extra attributes for ascend npus.
+            seq_lens_np=self.input_buffers.seq_lens_np,
+            attn_state=attn_state,
        )
+        return self.input_batch

    def postprocess(
        self,
@@ -352,7 +381,7 @@ class NPUModelRunner(GPUModelRunner):
            self.req_states.num_computed_tokens_cpu[req_index] = self.num_computed_tokens_cpu[req_index]

        # update seq_lens_cpu
-        for i, req_id in enumerate(req_ids):
+        for i, req_id in enumerate(req_ids):  # type: ignore
            req_index = self.req_states.req_id_to_index[req_id]
            num_computed_tokens = self.req_states.num_computed_tokens_cpu[req_index]
            self.input_buffers.seq_lens_cpu[i] = num_computed_tokens + num_scheduled_tokens[req_id]
@@ -361,3 +390,44 @@ class NPUModelRunner(GPUModelRunner):
        # TODO(Ronald1995): define this method just to avoid call errors in the
        # worker; implement it in the future.
        pass

    def _pad_query_start_loc_for_fia(
        self,
        num_tokens_padded: int,
        num_tokens: int,
        num_reqs: int,
        query_start_loc_np: np.ndarray,
        max_query_len: int,
    ) -> tuple[np.ndarray, int]:
        """
        This function is only designed to satisfy the constraint that, when the
        layout is TND, the first dimension of `hidden_states` must equal the
        last element of `actual_seq_lengths_q`.
        """
        assert self.cudagraph_and_dp_padding is not None
        _num_tokens_after_padding, _num_tokens_across_dp, synced_cudagraph_mode = self.cudagraph_and_dp_padding
        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
        if cudagraph_runtime_mode != CUDAGraphMode.FULL:
            return query_start_loc_np, num_reqs
        uniform_decode_query_len = self.cudagraph_manager.uniform_decode_query_len
        is_uniform_decode = self.cudagraph_manager.is_uniform_decode(
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            max_query_len=max_query_len,
        )
        if is_uniform_decode:
            # Uniform-batch case: num_reqs must be no greater than num_reqs_padded.
            num_reqs_padded = num_tokens_padded // uniform_decode_query_len

            last_loc = query_start_loc_np[num_reqs]
            query_start_loc_np[num_reqs + 1 : num_reqs_padded + 1] = (
                np.arange(1, num_reqs_padded + 1 - num_reqs) * uniform_decode_query_len + last_loc
            )
        else:
            # Mixed-batch case: num_reqs must equal num_reqs_padded.
            num_reqs_padded = min(num_tokens_padded, self.max_num_reqs)

            # Insert a dummy request instead of setting
            # query_start_loc[num_reqs] = num_tokens_padded directly.
            query_start_loc_np[num_reqs_padded + 1] = num_tokens_padded
            num_reqs_padded = num_reqs_padded + 1

        return query_start_loc_np, num_reqs_padded

vllm_ascend/worker/v2/model_states/__init__.py (new file, 34 lines)
@@ -0,0 +1,34 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_states/__init__.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache


def init_asecnd_model_state(
    vllm_config: VllmConfig,
    model: nn.Module,
    encoder_cache: EncoderCache | None,
    device: torch.device,
):
    from vllm_ascend.worker.v2.model_states.default import AscendModelState

    return AscendModelState(vllm_config, model, encoder_cache, device)
vllm_ascend/worker/v2/model_states/default.py (new file, 62 lines)
@@ -0,0 +1,62 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_states/default.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from typing import Any

import torch
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.model_states.default import DefaultModelState
from vllm.v1.worker.utils import AttentionGroup

from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
from vllm_ascend.worker.v2.input_batch import AscendInputBatch


class AscendModelState(DefaultModelState):
    """Model state for Ascend NPUs."""

    def prepare_attn(
        self,
        input_batch: AscendInputBatch,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
    ) -> dict[str, Any]:
        """Override prepare_attn because `build_attn_metadata` differs from vllm's."""
        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        attn_metadata = build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
            # extra attributes for ascend npus.
            seq_lens_np=input_batch.seq_lens_np,
            attn_state=input_batch.attn_state,
        )
        return attn_metadata
@@ -16,9 +16,6 @@
#
import numpy as np
import torch
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
-from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
-from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.sampler import Sampler

from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
@@ -53,21 +50,23 @@ class AscendSampler(Sampler):
            self.num_speculative_tokens,
        )

        # Apply bad words masking in place.
        self.bad_words_state.apply_bad_words(
            logits,
            idx_mapping,
            idx_mapping_np,
            input_ids,
            expanded_local_pos,
        )

        # Apply temperature in place.
-        apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)
+        self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)

-        # Apply min_p in place if any request has a non-zero min_p.
-        do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
-        if do_min_p:
-            apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
+        # Apply min_p in place.
+        self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)

-        # Apply top_k and/or top_p. This might return a new tensor.
-        do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
-        top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
-        do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
-        top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
-        if do_top_k or do_top_p:
-            logits = apply_top_k_top_p(logits, top_k, top_p)
+        # Apply top_k and/or top_p. This might or might not return a new tensor.
+        logits = self.sampling_states.apply_top_k_top_p(logits, idx_mapping, idx_mapping_np)

        # Sample the next token.
        sampled = gumbel_sample(
@@ -23,7 +23,7 @@ import torch
import vllm
from vllm.config import VllmConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
-from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
+from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.worker.v2.attn_utils import build_attn_metadata
@@ -56,13 +56,13 @@ class AscendRequestState(RequestState):
        self,
        req_id,
        prompt_len,
        prefill_token_ids,
        all_token_ids,
        num_computed_tokens,
    ):
        super().add_request(
            req_id,
            prompt_len,
            prefill_token_ids,
            all_token_ids,
            num_computed_tokens,
        )
        req_idx = self.req_id_to_index[req_id]
@@ -1,6 +1,11 @@
from contextlib import contextmanager

import torch
import vllm
from vllm.logger import logger

from vllm_ascend.worker.v2.block_table import AscendBlockTables
from vllm_ascend.worker.v2.model_states import init_asecnd_model_state


@contextmanager
@@ -15,6 +20,34 @@ def torch_cuda_wrapper():
        torch.cuda.CUDAGraph = torch.npu.NPUGraph
        torch.cuda.graph = torch.npu.graph
        torch.cuda.synchronize = torch.npu.synchronize
        torch.cuda.set_stream = torch.npu.set_stream
        torch.cuda.current_device = torch.npu.current_device
        torch.cuda.mem_get_info = torch.npu.mem_get_info
        logger.info_once("Wrapping torch.cuda with torch.npu.")
        yield
    finally:
        pass


@contextmanager
def block_table_wrapper():
    try:
        # vllm-ascend needs to initialize slot_mappings with dtype torch.int32,
        # but vllm's default is torch.int64.
        vllm.v1.worker.gpu.model_runner.BlockTables = AscendBlockTables
        logger.info_once("Wrapping BlockTables with AscendBlockTables.")
        yield
    finally:
        pass


@contextmanager
def model_states_wrapper():
    try:
        # prepare_attn in AscendModelState differs from vllm's, so we need to
        # override init_model_state.
        vllm.v1.worker.gpu.model_runner.init_model_state = init_asecnd_model_state
        logger.info_once("Wrapping init_model_state with init_asecnd_model_state.")
        yield
    finally:
        pass
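

# Note: NPUModelRunner.__init__ (see model_runner.py in this PR) enters the
# three wrappers above together around GPUModelRunner.__init__:
#   with torch_cuda_wrapper(), block_table_wrapper(), model_states_wrapper():
#       super().__init__(vllm_config, device)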