[Feature] support eager mode in model runner v2 (#5210)

### What this PR does / why we need it?
#5051 only implemented a basic framework for model runner v2; some bugs
still blocked end-to-end functionality. This PR aims to enable the basic
(eager-mode) functionality.

Model runner v2 plans:
https://github.com/vllm-project/vllm-ascend/issues/5208
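For context, a minimal sketch of how the eager path enabled here can be exercised (it mirrors the new e2e test added below; the model, prompt, and script structure are illustrative only, and `VLLM_USE_V2_MODEL_RUNNER` is the opt-in switch used throughout this diff):

import os
os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"  # opt in to model runner v2

from vllm import LLM, SamplingParams

# enforce_eager=True keeps execution in eager mode (no graph capture),
# which is the scope of this PR.
llm = LLM(model="Qwen/Qwen3-0.6B", max_model_len=1024, enforce_eager=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)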

- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Ronald
2025-12-29 15:28:34 +08:00
committed by GitHub
parent 4da46da9bf
commit e7e1a7dc05
19 changed files with 528 additions and 44 deletions

View File

@@ -125,6 +125,8 @@ jobs:
 pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
 pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+pytest -sv --durations=0 tests/e2e/singlecard/model_runner_v2/test_basic.py
 e2e-2-cards:
   name: multicard-2
   runs-on: linux-aarch64-a3-2

View File

@@ -0,0 +1,51 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from unittest.mock import patch

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner

MODELS = ["Qwen/Qwen3-0.6B"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("enforce_eager", [True])
@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
def test_qwen3_dense_eager_mode(
    model: str,
    max_tokens: int,
    enforce_eager: bool,
) -> None:
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
    with VllmRunner(
            model,
            max_model_len=1024,
            enforce_eager=enforce_eager,
    ) as runner:
        runner.model.generate(prompts, sampling_params)
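This test is the one wired into CI above (`pytest -sv --durations=0 tests/e2e/singlecard/model_runner_v2/test_basic.py`); the `patch.dict` decorator sets `VLLM_USE_V2_MODEL_RUNNER=1` for the duration of the test, so no manual environment setup is needed to run it locally.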

View File

@@ -21,6 +21,7 @@ from typing import ClassVar, List, Optional, Tuple, Type
import torch
import torch_npu
+import vllm.envs as envs_vllm
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer, AttentionType)
from vllm.attention.backends.registry import (AttentionBackendEnum,
@@ -54,7 +55,10 @@ class AscendAttentionBackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
-        return "CUSTOM"
+        # HACK(Ronald1995): vllm's `initialize_kv_cache` in model runner v2 asserts
+        # on the attention backend name, so report FLASH_ATTN to avoid the assertion
+        # error. Rectify this once vllm drops the assertion.
+        return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
@@ -535,7 +539,10 @@
                attn_metadata: AscendMetadata,
                output: torch.Tensor):
        forward_context: ForwardContext = get_forward_context()
-        if forward_context.capturing:
+        # We inherit ForwardContext in model runner v2; when model runner v2 is
+        # enabled there is no `capturing` attribute on forward_context, so use
+        # getattr to avoid an AttributeError.
+        if getattr(forward_context, "capturing", False):
            attn_output, num_tokens = self.full_graph_fia(
                query, key, value, attn_metadata, output)
            output[:num_tokens] = attn_output[:num_tokens]

View File

@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
import numpy as np
import torch
import torch_npu
+import vllm.envs as envs_vllm
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig, get_current_vllm_config
@@ -53,7 +54,10 @@ class AscendMLABackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
-        return "ASCEND_MLA"
+        # HACK(Ronald1995): vllm's `initialize_kv_cache` in model runner v2 asserts
+        # on the attention backend name, so report FLASH_ATTN to avoid the assertion
+        # error. Rectify this once vllm drops the assertion.
+        return "ASCEND_MLA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_builder_cls():

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
import torch
import torch_npu
+import vllm.envs as envs_vllm
from torch import nn
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
@@ -44,7 +45,10 @@ class AscendSFABackend(AttentionBackend):
    @staticmethod
    def get_name() -> str:
-        return "ASCEND_SFA"
+        # HACK(Ronald1995): vllm's `initialize_kv_cache` in model runner v2 asserts
+        # on the attention backend name, so report FLASH_ATTN to avoid the assertion
+        # error. Rectify this once vllm drops the assertion.
+        return "ASCEND_SFA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

    @staticmethod
    def get_builder_cls():

View File

@@ -81,7 +81,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor,
    except AssertionError:
        return tensor_model_parallel_all_reduce(x)

-    if not forward_context.sp_enabled:
+    if not getattr(forward_context, "sp_enabled", False):
        return tensor_model_parallel_all_reduce(x)

    dp_metadata = forward_context.dp_metadata

View File

@@ -0,0 +1,6 @@
# [Experimental] Model Runner V2
This directory contains the new model runner, which is under active development.
Please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208)
for the detailed plan.

View File

@@ -1,5 +1,21 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/aclgraph_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
from contextlib import contextmanager
from typing import Any

View File

@@ -1,5 +1,22 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/attn_utils.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
from collections.abc import Sequence
from typing import Any

View File

@@ -1,3 +1,22 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/input_batch.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
import numpy as np
import torch
from vllm.v1.worker.gpu.input_batch import InputBuffers

View File

@@ -1,5 +1,21 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_runner.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
import numpy as np
import torch
@@ -19,7 +35,8 @@ from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
                                              build_attn_state,
                                              make_attention_mask)
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
-from vllm_ascend.worker.v2.states import AscendRequestState
+from vllm_ascend.worker.v2.sample.sampler import AscendSampler
+from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper

logger = init_logger(__name__)
@@ -29,7 +46,7 @@ class NPUModelRunner(GPUModelRunner):
    """Model runner for Ascend NPUs."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
-        with torch_cuda_wrapper():
+        with (torch_cuda_wrapper(), uva_wrapper()):
            super().__init__(vllm_config, device)

        # because we will override these attribute, delete these attribute to
@@ -37,6 +54,7 @@ class NPUModelRunner(GPUModelRunner):
        del self.cudagraph_manager
        del self.req_states
        del self.input_buffers
+        del self.sampler

        # NPU specific initializations can be added below.
        self.cudagraph_manager: AclGraphManager = AclGraphManager(
@@ -65,6 +83,10 @@ class NPUModelRunner(GPUModelRunner):
            device=self.device,
            pin_memory=self.pin_memory,
        )
+        # we need to adjust triton operators in sampler,
+        # so reinitialize sampler here.
+        self.sampler: AscendSampler = AscendSampler(
+            logprobs_mode=self.model_config.logprobs_mode, )

        # actual seq lengths for query (used in attention backends).
        self.actual_seq_lengths_q: list[int] = []
@@ -206,7 +228,9 @@
            self.req_states.next_prefill_tokens,
            idx_mapping_npu,
            query_start_loc_gpu,
-            self.req_states.prefill_token_ids.gpu,
+            # use prefill_token_ids.copy_to_gpu() because npu doesn't
+            # support uva buffer.
+            self.req_states.prefill_token_ids.copy_to_gpu(),
            self.req_states.prefill_len.gpu,
            self.req_states.num_computed_tokens,
        )
@@ -255,7 +279,9 @@
            num_computed_tokens_cpu=self.req_states.
            num_computed_tokens_cpu[idx_mapping_cpu],
            block_tables=block_tables,
-            slot_mappings=slot_mappings,
+            # torch_npu._reshape_and_cache operator requires slot_mappings to
+            # be torch.int32.
+            slot_mappings=slot_mappings.to(torch.int32),
            kv_cache_config=self.kv_cache_config,
            decode_token_per_req=self.decode_token_per_req,
            attn_mask=attn_mask,
@@ -344,3 +370,8 @@
                req_index]
            self.input_buffers.seq_lens_cpu[
                i] = num_computed_tokens + num_scheduled_tokens[req_id]
+
+    def eplb_warmup(self):
+        # TODO(Ronald1995): just define the method in case calling error in
+        # worker, implement it in the future.
+        pass
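The two NPU-specific tweaks above can be illustrated in plain PyTorch (an illustrative sketch; the tensor names and shapes are hypothetical): without UVA support, pinned host buffers must be copied to device memory explicitly, and the slot mapping handed to the NPU reshape-and-cache operator has to be int32.

import torch
import torch_npu  # noqa: F401  # registers the "npu" device with PyTorch

# Hypothetical host-side buffer of prompt token ids (pinned in the real runner).
prefill_token_ids_cpu = torch.zeros(4, 1024, dtype=torch.int32)

# No UVA on NPU: device kernels cannot read the host buffer directly, so it is
# copied explicitly instead of exposing a ".gpu" view of the same memory.
prefill_token_ids_npu = prefill_token_ids_cpu.to("npu", non_blocking=True)

# Slot mappings are int64 in the generic runner; torch_npu._reshape_and_cache
# expects int32, hence the cast before building attention metadata.
slot_mappings = torch.arange(4 * 128, device="npu", dtype=torch.int64)
slot_mappings = slot_mappings.to(torch.int32)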

View File

View File

@@ -0,0 +1,128 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/gumbel.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.triton_utils import tl, triton


@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    logits_ptr,
    logits_stride,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(
        logits_ptr + req_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    temp = tl.load(temp_ptr + req_idx).to(tl.float32)
    if temp != 0.0:
        # Calculate the seed for gumbel noise.
        seed = tl.load(seeds_ptr + req_idx)
        # NOTE(Ronald1995): change pos's dtype to tl.int32, because triton-ascend's
        # compiler doesn't support uint64 for the pos arg.
        pos = tl.load(pos_ptr + req_idx).to(tl.int32)
        gumbel_seed = tl.randint(seed, pos)

        # Generate gumbel noise.
        # NOTE(Ronald1995): r is tl.float64 in vllm; change it to tl.float32,
        # or triton-ascend's compiler will raise an error.
        r = tl.rand(gumbel_seed, block).to(tl.float32)
        gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
        gumbel_noise = gumbel_noise.to(tl.float32)

        # Apply temperature.
        if APPLY_TEMPERATURE:
            # NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
            # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
            logits = logits / temp
        # Apply gumbel noise.
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

    idx = tl.argmax(logits, axis=0)
    token_id = block_idx * BLOCK_SIZE + idx
    value = tl.max(logits, axis=0)
    tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx,
             token_id)
    tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)


def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    """Override of vllm's gumbel_sample: _gumbel_sample_kernel hits some
    triton-ascend bugs on NPU, so a few fixes are applied here; see the
    NOTE(Ronald1995) comments for details.
    """
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    local_argmax = torch.empty(
        num_reqs,
        num_blocks,
        dtype=torch.int64,
        device=logits.device,
    )
    local_max = torch.empty(
        num_reqs,
        num_blocks,
        dtype=torch.float32,
        device=logits.device,
    )
    # TODO(Ronald1995): Optimize the performance of the kernel on npu.
    _gumbel_sample_kernel[(num_reqs, num_blocks)](
        local_argmax,
        local_argmax.stride(0),
        local_max,
        local_max.stride(0),
        logits,
        logits.stride(0),
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        APPLY_TEMPERATURE=apply_temperature,
    )
    # NOTE(woosuk): Use int64 for later indexing.
    max_block_idx = local_max.argmax(dim=-1, keepdim=True)
    sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
    return sampled
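For reference, the Gumbel-max trick this kernel implements can be written in plain PyTorch (an illustrative sketch only, not the triton path; `gumbel_argmax_reference` is a hypothetical helper name):

import torch

def gumbel_argmax_reference(logits: torch.Tensor,
                            temperature: torch.Tensor) -> torch.Tensor:
    """Reference Gumbel-max sampling: argmax(logits / T + G) with G ~ Gumbel(0, 1)
    is distributed like drawing from softmax(logits / T); T == 0 falls back to
    greedy argmax, matching the kernel's temp != 0.0 branch."""
    logits = logits.float()                                   # [num_reqs, vocab]
    scaled = logits / temperature.clamp_min(1e-6).unsqueeze(-1)
    r = torch.rand_like(scaled)
    gumbel = -torch.log(-torch.log(r + 1e-20) + 1e-20)
    noisy = torch.where(temperature.unsqueeze(-1) == 0.0, logits, scaled + gumbel)
    return noisy.argmax(dim=-1)                               # [num_reqs]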

View File

@@ -0,0 +1,137 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/penalties.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata


@triton.jit
def _penalties_and_temperature_kernel(
    logits_ptr,
    logits_stride,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    temperature_ptr,
    idx_mapping_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
    temperature = tl.load(temperature_ptr + batch_idx)
    temperature = tl.where(temperature == 0.0, 1.0, temperature)

    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    # NOTE(Ronald1995): vllm's original expression is `use_rep_penalty or
    # use_freq_penalty or use_pres_penalty`; change it to
    # `(use_rep_penalty or use_freq_penalty) or use_pres_penalty`,
    # because triton-ascend's compiler doesn't support chained boolean operators.
    use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
    use_temperature = temperature != 1.0
    if not (use_penalty or use_temperature):
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)

    if use_penalty:
        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
        output_bin_counts = tl.load(
            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride +
            block,
            mask=mask,
        )
        output_bin_mask = output_bin_counts > 0

        # Apply repetition penalties.
        if use_rep_penalty:
            packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(
                0, BLOCK_SIZE // 32)
            packed_mask = tl.load(
                prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride +
                packed_block,
                mask=packed_block < tl.cdiv(vocab_size, 32),
            )
            prompt_bin_mask = (packed_mask[:, None] >>
                               (tl.arange(0, 32)[None, :])) & 1
            prompt_bin_mask = prompt_bin_mask.to(tl.int1)
            prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)

            # If the token appears in the prompt or output, apply the penalty;
            # otherwise use 1.0 for a no-op.
            scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty,
                             1.0)
            # If logits are positive, divide by the penalty; otherwise multiply.
            logits *= tl.where(logits > 0, 1.0 / scale, scale)

        # Apply frequency penalties.
        logits -= freq_penalty * output_bin_counts
        # Apply presence penalties.
        logits -= pres_penalty * output_bin_mask

    # Apply temperature.
    logits = logits / temperature
    # Store back to logits.
    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)


def apply_penalties_and_temperature(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> None:
    """Override of vllm's apply_penalties_and_temperature:
    _penalties_and_temperature_kernel hits some triton-ascend bugs on NPU,
    so a few fixes are applied here; see the NOTE(Ronald1995) comments.
    """
    num_reqs, vocab_size = logits.shape
    # NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 in case of UB overflow
    # in triton-ascend.
    BLOCK_SIZE = 4096
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    # TODO(Ronald1995): Optimize the performance of the kernel on npu.
    _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
        logits,
        logits.stride(0),
        sampling_metadata.repetition_penalty,
        sampling_metadata.frequency_penalty,
        sampling_metadata.presence_penalty,
        sampling_metadata.temperature,
        sampling_metadata.idx_mapping,
        sampling_metadata.prompt_bin_mask,
        sampling_metadata.prompt_bin_mask.stride(0),
        sampling_metadata.output_bin_counts,
        sampling_metadata.output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
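For reference, the per-request math the kernel applies can be sketched in plain PyTorch (illustrative only; `penalties_reference` and its argument names are hypothetical, with `prompt_mask` and `output_counts` being [vocab_size] tensors for a single request):

import torch

def penalties_reference(logits, rep_penalty, freq_penalty, pres_penalty,
                        temperature, prompt_mask, output_counts):
    """Reference for one request: logits [vocab_size], scalar penalties."""
    logits = logits.float()
    output_mask = output_counts > 0
    # Repetition penalty: shrink logits of tokens seen in the prompt or output.
    scale = torch.where(prompt_mask | output_mask,
                        torch.full_like(logits, rep_penalty),
                        torch.ones_like(logits))
    logits = torch.where(logits > 0, logits / scale, logits * scale)
    # Frequency and presence penalties.
    logits = logits - freq_penalty * output_counts - pres_penalty * output_mask
    # Temperature (0.0 is treated as 1.0; greedy decoding is handled elsewhere).
    return logits / (temperature if temperature != 0.0 else 1.0)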

View File

@@ -0,0 +1,58 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/sampler.py.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.sampler import Sampler

from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
from vllm_ascend.worker.v2.sample.penalties import \
    apply_penalties_and_temperature


class AscendSampler(Sampler):

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Override the sample method because the triton operators it calls
        need the NPU-specific versions defined in this package.
        """
        # Copy logits to a new FP32 tensor.
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
        # Apply penalties and temperature in place.
        apply_penalties_and_temperature(logits, sampling_metadata)
        # Apply min_p in place.
        if sampling_metadata.min_p is not None:
            apply_min_p(logits, sampling_metadata.min_p)
        # Apply top_k and/or top_p. This might return a new tensor.
        logits = apply_top_k_top_p(logits, sampling_metadata.top_k,
                                   sampling_metadata.top_p)
        sampled = gumbel_sample(
            logits,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
            apply_temperature=False,
        )
        return sampled, logits
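Note that `gumbel_sample` is invoked with `apply_temperature=False`: `apply_penalties_and_temperature` has already divided the logits by the temperature in place, so applying it again inside the sampling kernel would double-scale them.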

View File

@@ -1,8 +1,28 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/states.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
from contextlib import contextmanager

import torch
+import vllm
from vllm.v1.utils import CpuGpuBuffer
-from vllm.v1.worker.gpu.states import RequestState, UvaBuffer
+from vllm.v1.worker.gpu.states import RequestState


class AscendRequestState(RequestState):
@@ -18,16 +38,15 @@ class AscendRequestState(RequestState):
        device: torch.device,
        pin_memory: bool,
    ):
-        with uva_wrapper():
-            super().__init__(
-                max_num_reqs,
-                max_model_len,
-                max_num_batched_tokens,
-                num_speculative_steps,
-                vocab_size,
-                device,
-                pin_memory,
-            )
+        super().__init__(
+            max_num_reqs,
+            max_model_len,
+            max_num_batched_tokens,
+            num_speculative_steps,
+            vocab_size,
+            device,
+            pin_memory,
+        )

        # because we will override these attribute, delete these attribute to
        # make sure it's collected by python gc immediately.
        del self.prefill_token_ids
@@ -78,11 +97,9 @@ def uva_wrapper():
        def __init__(self, *args, **kwargs):
            pass

-    # TODO(Ronald1995): rectify this when NPU support uva.
-    global UvaBuffer
-    ori_class = UvaBuffer
    try:
-        UvaBuffer = UvaBufferWrapper
+        # TODO(Ronald1995): rectify this when NPU support uva.
+        vllm.v1.worker.gpu.states.UvaBuffer = UvaBufferWrapper
        yield
    finally:
-        UvaBuffer = ori_class
+        pass

View File

@@ -5,29 +5,16 @@ import torch
@contextmanager
def torch_cuda_wrapper():
-    ori_event = torch.cuda.Event
-    ori_stream = torch.cuda.Stream
-    ori_default_stream = torch.cuda.default_stream
-    ori_current_stream = torch.cuda.current_stream
-    ori_graph_pool_handle = torch.cuda.graph_pool_handle
-    ori_cuda_graph_cls = torch.cuda.CUDAGraph
-    ori_cuda_graph_func = torch.cuda.graph
    try:
        torch.cuda.Event = torch.npu.Event
        torch.cuda.Stream = torch.npu.Stream
+        torch.cuda.stream = torch.npu.stream
        torch.cuda.default_stream = torch.npu.default_stream
        torch.cuda.current_stream = torch.npu.current_stream
        torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle
-        torch.cuda.CUDAGraph = torch.npu.NpuGraph
+        torch.cuda.CUDAGraph = torch.npu.NPUGraph
        torch.cuda.graph = torch.npu.graph
+        torch.cuda.synchronize = torch.npu.synchronize
        yield
    finally:
-        # revert back torch cuda properties, so it will still raise error
-        # to call cuda ops in npu environment.
-        torch.cuda.Event = ori_event
-        torch.cuda.Stream = ori_stream
-        torch.cuda.default_stream = ori_default_stream
-        torch.cuda.current_stream = ori_current_stream
-        torch.cuda.graph_pool_handle = ori_graph_pool_handle
-        torch.cuda.CUDAGraph = ori_cuda_graph_cls
-        torch.cuda.graph = ori_cuda_graph_func
+        pass
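A usage sketch of the wrapper (assuming a torch_npu install; this mirrors how the model runner wraps its `__init__` above, and inside the context `torch.cuda.*` calls from inherited GPU code resolve to their `torch.npu` equivalents):

import torch
import torch_npu  # noqa: F401  # provides the torch.npu namespace

from vllm_ascend.worker.v2.utils import torch_cuda_wrapper

with torch_cuda_wrapper():
    # Code written against torch.cuda (e.g. vllm's GPUModelRunner.__init__)
    # now transparently targets the NPU backend.
    stream = torch.cuda.current_stream()   # -> torch.npu.current_stream()
    event = torch.cuda.Event()             # -> torch.npu.Event()
    event.record(stream)
    torch.cuda.synchronize()               # -> torch.npu.synchronize()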

View File

@@ -233,8 +233,8 @@ class NPUWorker(WorkerBase):
        self.device = self._init_device()
        # Init ModelRunner here, so that we have access to self.device.
        if self.use_v2_model_runner:
-            logger.error(
-                "npu model runner v2 is in developing, it can't work well for now."
+            logger.warning(
+                "npu model runner v2 is under development; some features don't work yet."
            )
            from vllm_ascend.worker.v2.model_runner import \
                NPUModelRunner as NPUModelRunnerV2