[Feature] support eager mode in model runner v2 (#5210)

### What this PR does / why we need it?
#5051 only implement a basic framework for model runner v2, but there
are still some bugs for e2e functionality, this PR aim to enable basic
functionality.
model runner v2 plans:
https://github.com/vllm-project/vllm-ascend/issues/5208

- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2025-12-29 15:28:34 +08:00
committed by GitHub
parent 4da46da9bf
commit e7e1a7dc05
19 changed files with 528 additions and 44 deletions

View File

@@ -125,6 +125,8 @@ jobs:
pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
pytest -sv --durations=0 tests/e2e/singlecard/model_runner_v2/test_basic.py
e2e-2-cards:
name: multicard-2
runs-on: linux-aarch64-a3-2

View File

@@ -0,0 +1,51 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from unittest.mock import patch
import pytest
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
MODELS = ["Qwen/Qwen3-0.6B"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("enforce_eager", [True])
@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
def test_qwen3_dense_eager_mode(
model: str,
max_tokens: int,
enforce_eager: bool,
) -> None:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=enforce_eager,
) as runner:
runner.model.generate(prompts, sampling_params)

View File

@@ -21,6 +21,7 @@ from typing import ClassVar, List, Optional, Tuple, Type
import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.registry import (AttentionBackendEnum,
@@ -54,7 +55,10 @@ class AscendAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "CUSTOM"
# HACK(Ronald1995): vllm `initialize_kv_cache` method in model runner v2 make
# attention name assertion, we just set name to FLASH_ATTN to avoid assertion error.
# rectify this when vllm disable the assertion.
return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
@@ -535,7 +539,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
attn_metadata: AscendMetadata,
output: torch.Tensor):
forward_context: ForwardContext = get_forward_context()
if forward_context.capturing:
# we inherit ForwardContext in model runner v2, when enable model
# runner v2, there is not capturing attribute in forward_context,
# just use getattr to avoid attribute error.
if getattr(forward_context, "capturing", False):
attn_output, num_tokens = self.full_graph_fia(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]

View File

@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
import numpy as np
import torch
import torch_npu
import vllm.envs as envs_vllm
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig, get_current_vllm_config
@@ -53,7 +54,10 @@ class AscendMLABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "ASCEND_MLA"
# HACK(Ronald1995): vllm `initialize_kv_cache` method in model runner v2 make
# attention name assertion, we just set name to FLASH_ATTN to avoid assertion error.
# rectify this when vllm disable the assertion.
return "ASCEND_MLA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
@staticmethod
def get_builder_cls():

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
import torch
import torch_npu
import vllm.envs as envs_vllm
from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
@@ -44,7 +45,10 @@ class AscendSFABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "ASCEND_SFA"
# HACK(Ronald1995): vllm `initialize_kv_cache` method in model runner v2 make
# attention name assertion, we just set name to FLASH_ATTN to avoid assertion error.
# rectify this when vllm disable the assertion.
return "ASCEND_SFA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
@staticmethod
def get_builder_cls():

View File

@@ -81,7 +81,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor,
except AssertionError:
return tensor_model_parallel_all_reduce(x)
if not forward_context.sp_enabled:
if not getattr(forward_context, "sp_enabled", False):
return tensor_model_parallel_all_reduce(x)
dp_metadata = forward_context.dp_metadata

View File

@@ -0,0 +1,6 @@
# [Experimental] Model Runner V2
This directory contains the new model runner which is under active development.
please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208)
to get specific plans.

View File

@@ -1,5 +1,21 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/aclgraph_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from contextlib import contextmanager
from typing import Any

View File

@@ -1,5 +1,22 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/attn_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from collections.abc import Sequence
from typing import Any

View File

@@ -1,3 +1,22 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/input_batch.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import numpy as np
import torch
from vllm.v1.worker.gpu.input_batch import InputBuffers

View File

@@ -1,5 +1,21 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_runner.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import numpy as np
import torch
@@ -19,7 +35,8 @@ from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
build_attn_state,
make_attention_mask)
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
from vllm_ascend.worker.v2.states import AscendRequestState
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
logger = init_logger(__name__)
@@ -29,7 +46,7 @@ class NPUModelRunner(GPUModelRunner):
"""Model runner for Ascend NPUs."""
def __init__(self, vllm_config: VllmConfig, device: torch.device):
with torch_cuda_wrapper():
with (torch_cuda_wrapper(), uva_wrapper()):
super().__init__(vllm_config, device)
# because we will override these attribute, delete these attribute to
@@ -37,6 +54,7 @@ class NPUModelRunner(GPUModelRunner):
del self.cudagraph_manager
del self.req_states
del self.input_buffers
del self.sampler
# NPU specific initializations can be added below.
self.cudagraph_manager: AclGraphManager = AclGraphManager(
@@ -65,6 +83,10 @@ class NPUModelRunner(GPUModelRunner):
device=self.device,
pin_memory=self.pin_memory,
)
# we need to adjust triton operators in sampler,
# so reinitialize sampler here.
self.sampler: AscendSampler = AscendSampler(
logprobs_mode=self.model_config.logprobs_mode, )
# actual seq lengths for query (used in attention backends).
self.actual_seq_lengths_q: list[int] = []
@@ -206,7 +228,9 @@ class NPUModelRunner(GPUModelRunner):
self.req_states.next_prefill_tokens,
idx_mapping_npu,
query_start_loc_gpu,
self.req_states.prefill_token_ids.gpu,
# use prefill_token_ids.copy_to_gpu() because npu doesn't
# support uva buffer.
self.req_states.prefill_token_ids.copy_to_gpu(),
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens,
)
@@ -255,7 +279,9 @@ class NPUModelRunner(GPUModelRunner):
num_computed_tokens_cpu=self.req_states.
num_computed_tokens_cpu[idx_mapping_cpu],
block_tables=block_tables,
slot_mappings=slot_mappings,
# torch_npu._reshape_and_cache operator requires slot_mappings to
# be torch.int32.
slot_mappings=slot_mappings.to(torch.int32),
kv_cache_config=self.kv_cache_config,
decode_token_per_req=self.decode_token_per_req,
attn_mask=attn_mask,
@@ -344,3 +370,8 @@ class NPUModelRunner(GPUModelRunner):
req_index]
self.input_buffers.seq_lens_cpu[
i] = num_computed_tokens + num_scheduled_tokens[req_id]
def eplb_warmup(self):
# TODO(Ronald1995): just define the method in case calling error in
# worker, implement it in the future.
pass

View File

View File

@@ -0,0 +1,128 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/gumbel.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
logits_ptr,
logits_stride,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + req_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_idx)
# NOTE(Ronald1995): change pos's dtype to tl.int32, because triton-ascend's
# compiler doesn't support unint64 of pos arg.
pos = tl.load(pos_ptr + req_idx).to(tl.int32)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise.
# NOTE(Ronald1995): r is tl.float64 in vllm, change it to tl.float32,
# or triton-ascend's compiler will raise error.
r = tl.rand(gumbel_seed, block).to(tl.float32)
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
gumbel_noise = gumbel_noise.to(tl.float32)
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
idx = tl.argmax(logits, axis=0)
token_id = block_idx * BLOCK_SIZE + idx
value = tl.max(logits, axis=0)
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx,
token_id)
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size]
temperature: torch.Tensor, # [num_reqs]
seed: torch.Tensor, # [num_reqs]
pos: torch.Tensor, # [num_reqs]
apply_temperature: bool,
) -> torch.Tensor:
"""Override the function because there are some bugs
when _gumbel_sample_kernel runs on npu, we need to make some fixes.
you could read NOTE(Ronald1995) comments to understand.
"""
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty(
num_reqs,
num_blocks,
dtype=torch.int64,
device=logits.device,
)
local_max = torch.empty(
num_reqs,
num_blocks,
dtype=torch.float32,
device=logits.device,
)
# TODO(Ronald1995): Optimize the performance of the kernel in npu.
_gumbel_sample_kernel[(num_reqs, num_blocks)](
local_argmax,
local_argmax.stride(0),
local_max,
local_max.stride(0),
logits,
logits.stride(0),
seed,
pos,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
APPLY_TEMPERATURE=apply_temperature,
)
# NOTE(woosuk): Use int64 for later indexing.
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
return sampled

View File

@@ -0,0 +1,137 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/penalties.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
@triton.jit
def _penalties_and_temperature_kernel(
logits_ptr,
logits_stride,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
temperature_ptr,
idx_mapping_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
temperature = tl.load(temperature_ptr + batch_idx)
temperature = tl.where(temperature == 0.0, 1.0, temperature)
use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0
# NOTE(Ronald1995): vllm original grammar `use_rep_penalty or
# use_freq_penalty or use_pres_penalty`,
# change it to `(use_rep_penalty or use_freq_penalty) or use_pres_penalty`,
# because triton-ascend's compiler doesn't support chained boolean operator.
use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
use_temperature = temperature != 1.0
if not (use_penalty or use_temperature):
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
if use_penalty:
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride +
block,
mask=mask,
)
output_bin_mask = output_bin_counts > 0
# Apply repetition penalties.
if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(
0, BLOCK_SIZE // 32)
packed_mask = tl.load(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride +
packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32),
)
prompt_bin_mask = (packed_mask[:, None] >>
(tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
logits *= tl.where(logits > 0, 1.0 / scale, scale)
# Apply frequency penalties.
logits -= freq_penalty * output_bin_counts
# Apply presence penalties.
logits -= pres_penalty * output_bin_mask
# Apply temperature.
logits = logits / temperature
# Store back to logits.
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_penalties_and_temperature(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> None:
"""Override the function because there are some bugs
when _penalties_and_temperature_kernel runs on npu, we need to make some fixes.
you could read NOTE(Ronald1995) comments to understand.
"""
num_reqs, vocab_size = logits.shape
# NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 in case UB overflow
# in triton-ascend.
BLOCK_SIZE = 4096
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
# TODO(Ronald1995): Optimize the performance of the kernel in npu.
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
sampling_metadata.repetition_penalty,
sampling_metadata.frequency_penalty,
sampling_metadata.presence_penalty,
sampling_metadata.temperature,
sampling_metadata.idx_mapping,
sampling_metadata.prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts,
sampling_metadata.output_bin_counts.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -0,0 +1,58 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/sampler.py.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
from vllm_ascend.worker.v2.sample.penalties import \
apply_penalties_and_temperature
class AscendSampler(Sampler):
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Override sample method because we need to override triton operators
called in the method.
"""
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply penalties and temperature in place.
apply_penalties_and_temperature(logits, sampling_metadata)
# Apply min_p in place.
if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(logits, sampling_metadata.top_k,
sampling_metadata.top_p)
sampled = gumbel_sample(
logits,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
apply_temperature=False,
)
return sampled, logits

View File

@@ -1,8 +1,28 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/states.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from contextlib import contextmanager
import torch
import vllm
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu.states import RequestState, UvaBuffer
from vllm.v1.worker.gpu.states import RequestState
class AscendRequestState(RequestState):
@@ -18,16 +38,15 @@ class AscendRequestState(RequestState):
device: torch.device,
pin_memory: bool,
):
with uva_wrapper():
super().__init__(
max_num_reqs,
max_model_len,
max_num_batched_tokens,
num_speculative_steps,
vocab_size,
device,
pin_memory,
)
super().__init__(
max_num_reqs,
max_model_len,
max_num_batched_tokens,
num_speculative_steps,
vocab_size,
device,
pin_memory,
)
# because we will override these attribute, delete these attribute to
# make sure it's collected by python gc immediately.
del self.prefill_token_ids
@@ -78,11 +97,9 @@ def uva_wrapper():
def __init__(self, *args, **kwargs):
pass
# TODO(Ronald1995): rectify this when NPU support uva.
global UvaBuffer
ori_class = UvaBuffer
try:
UvaBuffer = UvaBufferWrapper
# TODO(Ronald1995): rectify this when NPU support uva.
vllm.v1.worker.gpu.states.UvaBuffer = UvaBufferWrapper
yield
finally:
UvaBuffer = ori_class
pass

View File

@@ -5,29 +5,16 @@ import torch
@contextmanager
def torch_cuda_wrapper():
ori_event = torch.cuda.Event
ori_stream = torch.cuda.Stream
ori_default_stream = torch.cuda.default_stream
ori_current_stream = torch.cuda.current_stream
ori_graph_pool_handle = torch.cuda.graph_pool_handle
ori_cuda_graph_cls = torch.cuda.CUDAGraph
ori_cuda_graph_func = torch.cuda.graph
try:
torch.cuda.Event = torch.npu.Event
torch.cuda.Stream = torch.npu.Stream
torch.cuda.stream = torch.npu.stream
torch.cuda.default_stream = torch.npu.default_stream
torch.cuda.current_stream = torch.npu.current_stream
torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle
torch.cuda.CUDAGraph = torch.npu.NpuGraph
torch.cuda.CUDAGraph = torch.npu.NPUGraph
torch.cuda.graph = torch.npu.graph
torch.cuda.synchronize = torch.npu.synchronize
yield
finally:
# revert back torch cuda properties, so it will still raise error
# to call cuda ops in npu environment.
torch.cuda.Event = ori_event
torch.cuda.Stream = ori_stream
torch.cuda.default_stream = ori_default_stream
torch.cuda.current_stream = ori_current_stream
torch.cuda.graph_pool_handle = ori_graph_pool_handle
torch.cuda.CUDAGraph = ori_cuda_graph_cls
torch.cuda.graph = ori_cuda_graph_func
pass

View File

@@ -233,8 +233,8 @@ class NPUWorker(WorkerBase):
self.device = self._init_device()
# Init ModelRunner here, so that we have access to self.device.
if self.use_v2_model_runner:
logger.error(
"npu model runner v2 is in developing, it can't work well for now."
logger.warning(
"npu model runner v2 is in developing, some features doesn't work for now."
)
from vllm_ascend.worker.v2.model_runner import \
NPUModelRunner as NPUModelRunnerV2