diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index 7418c60a..c28b3536 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -125,6 +125,8 @@ jobs:
           pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
           pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
 
+          pytest -sv --durations=0 tests/e2e/singlecard/model_runner_v2/test_basic.py
+
   e2e-2-cards:
     name: multicard-2
     runs-on: linux-aarch64-a3-2
diff --git a/tests/e2e/singlecard/model_runner_v2/__init__.py b/tests/e2e/singlecard/model_runner_v2/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/e2e/singlecard/model_runner_v2/test_basic.py b/tests/e2e/singlecard/model_runner_v2/test_basic.py
new file mode 100644
index 00000000..83bbe898
--- /dev/null
+++ b/tests/e2e/singlecard/model_runner_v2/test_basic.py
@@ -0,0 +1,51 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from unittest.mock import patch
+
+import pytest
+from vllm import SamplingParams
+
+from tests.e2e.conftest import VllmRunner
+
+MODELS = ["Qwen/Qwen3-0.6B"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("enforce_eager", [True])
+@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
+def test_qwen3_dense_eager_mode(
+    model: str,
+    max_tokens: int,
+    enforce_eager: bool,
+) -> None:
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            enforce_eager=enforce_eager,
+    ) as runner:
+        runner.model.generate(prompts, sampling_params)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index a87ed516..ecd80da3 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -21,6 +21,7 @@ from typing import ClassVar, List, Optional, Tuple, Type
 
 import torch
 import torch_npu
+import vllm.envs as envs_vllm
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.registry import (AttentionBackendEnum,
@@ -54,7 +55,10 @@ class AscendAttentionBackend(AttentionBackend):
 
     @staticmethod
    def get_name() -> str:
-        return "CUSTOM"
+        # HACK(Ronald1995): vllm's `initialize_kv_cache` method in model runner v2
+        # asserts on the attention backend name; report FLASH_ATTN when the v2 runner
+        # is enabled to avoid the assertion error. Rectify this once vllm drops the assertion.
+        return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
 
     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
@@ -535,7 +539,10 @@
                 attn_metadata: AscendMetadata,
                 output: torch.Tensor):
         forward_context: ForwardContext = get_forward_context()
-        if forward_context.capturing:
+        # We inherit ForwardContext in model runner v2; when the v2 runner is
+        # enabled there is no `capturing` attribute on forward_context, so use
+        # getattr to avoid an AttributeError.
+        if getattr(forward_context, "capturing", False):
             attn_output, num_tokens = self.full_graph_fia(
                 query, key, value, attn_metadata, output)
             output[:num_tokens] = attn_output[:num_tokens]
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index ee51b076..8b95e20f 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
 import numpy as np
 import torch
 import torch_npu
+import vllm.envs as envs_vllm
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig, get_current_vllm_config
@@ -53,7 +54,10 @@ class AscendMLABackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "ASCEND_MLA"
+        # HACK(Ronald1995): vllm's `initialize_kv_cache` method in model runner v2
+        # asserts on the attention backend name; report FLASH_ATTN when the v2 runner
+        # is enabled to avoid the assertion error. Rectify this once vllm drops the assertion.
+        return "ASCEND_MLA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
 
     @staticmethod
     def get_builder_cls():
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 8c3a2226..38a13e00 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
 
 import torch
 import torch_npu
+import vllm.envs as envs_vllm
 from torch import nn
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
@@ -44,7 +45,10 @@ class AscendSFABackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "ASCEND_SFA"
+        # HACK(Ronald1995): vllm's `initialize_kv_cache` method in model runner v2
+        # asserts on the attention backend name; report FLASH_ATTN when the v2 runner
+        # is enabled to avoid the assertion error. Rectify this once vllm drops the assertion.
+        return "ASCEND_SFA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"
 
     @staticmethod
     def get_builder_cls():
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index 8403f438..4e44d8cf 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -81,7 +81,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor,
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)
 
-    if not forward_context.sp_enabled:
+    if not getattr(forward_context, "sp_enabled", False):
         return tensor_model_parallel_all_reduce(x)
 
     dp_metadata = forward_context.dp_metadata
diff --git a/vllm_ascend/worker/v2/README.md b/vllm_ascend/worker/v2/README.md
new file mode 100644
index 00000000..3a83901d
--- /dev/null
+++ b/vllm_ascend/worker/v2/README.md
@@ -0,0 +1,25 @@
+# [Experimental] Model Runner V2
+
+This directory contains the new model runner, which is under active development.
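+
+The runner is gated behind vllm's `VLLM_USE_V2_MODEL_RUNNER` environment
+variable (see `vllm_ascend/worker/worker.py`). As a rough sketch mirroring the
+basic e2e test, it can be exercised offline like this:
+
+```python
+import os
+
+# Experimental switch; set it before vLLM creates its engine and workers.
+os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"
+
+from vllm import LLM, SamplingParams
+
+# Eager mode is the safest starting point while the runner is in development.
+llm = LLM(model="Qwen/Qwen3-0.6B", max_model_len=1024, enforce_eager=True)
+outputs = llm.generate(["The capital of France is"],
+                       SamplingParams(max_tokens=32, temperature=0.0))
+print(outputs[0].outputs[0].text)
+```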
+
+Please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208)
+for the detailed plan.
diff --git a/vllm_ascend/worker/v2/aclgraph_utils.py b/vllm_ascend/worker/v2/aclgraph_utils.py
index b6460fd0..1fab82d2 100644
--- a/vllm_ascend/worker/v2/aclgraph_utils.py
+++ b/vllm_ascend/worker/v2/aclgraph_utils.py
@@ -1,5 +1,21 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/aclgraph_utils.py
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
 
 from contextlib import contextmanager
 from typing import Any
diff --git a/vllm_ascend/worker/v2/attn_utils.py b/vllm_ascend/worker/v2/attn_utils.py
index a66e5a21..655c0369 100644
--- a/vllm_ascend/worker/v2/attn_utils.py
+++ b/vllm_ascend/worker/v2/attn_utils.py
@@ -1,5 +1,22 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/attn_utils.py
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
 
 from collections.abc import Sequence
 from typing import Any
diff --git a/vllm_ascend/worker/v2/input_batch.py b/vllm_ascend/worker/v2/input_batch.py
index 843ec658..e3b87cf2 100644
--- a/vllm_ascend/worker/v2/input_batch.py
+++ b/vllm_ascend/worker/v2/input_batch.py
@@ -1,3 +1,22 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/input_batch.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project. +# + import numpy as np import torch from vllm.v1.worker.gpu.input_batch import InputBuffers diff --git a/vllm_ascend/worker/v2/model_runner.py b/vllm_ascend/worker/v2/model_runner.py index f85304d6..447fdf8d 100644 --- a/vllm_ascend/worker/v2/model_runner.py +++ b/vllm_ascend/worker/v2/model_runner.py @@ -1,5 +1,21 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_runner.py # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# import numpy as np import torch @@ -19,7 +35,8 @@ from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata, build_attn_state, make_attention_mask) from vllm_ascend.worker.v2.input_batch import AscendInputBuffers -from vllm_ascend.worker.v2.states import AscendRequestState +from vllm_ascend.worker.v2.sample.sampler import AscendSampler +from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper from vllm_ascend.worker.v2.utils import torch_cuda_wrapper logger = init_logger(__name__) @@ -29,7 +46,7 @@ class NPUModelRunner(GPUModelRunner): """Model runner for Ascend NPUs.""" def __init__(self, vllm_config: VllmConfig, device: torch.device): - with torch_cuda_wrapper(): + with (torch_cuda_wrapper(), uva_wrapper()): super().__init__(vllm_config, device) # because we will override these attribute, delete these attribute to @@ -37,6 +54,7 @@ class NPUModelRunner(GPUModelRunner): del self.cudagraph_manager del self.req_states del self.input_buffers + del self.sampler # NPU specific initializations can be added below. self.cudagraph_manager: AclGraphManager = AclGraphManager( @@ -65,6 +83,10 @@ class NPUModelRunner(GPUModelRunner): device=self.device, pin_memory=self.pin_memory, ) + # we need to adjust triton operators in sampler, + # so reinitialize sampler here. + self.sampler: AscendSampler = AscendSampler( + logprobs_mode=self.model_config.logprobs_mode, ) # actual seq lengths for query (used in attention backends). self.actual_seq_lengths_q: list[int] = [] @@ -206,7 +228,9 @@ class NPUModelRunner(GPUModelRunner): self.req_states.next_prefill_tokens, idx_mapping_npu, query_start_loc_gpu, - self.req_states.prefill_token_ids.gpu, + # use prefill_token_ids.copy_to_gpu() because npu doesn't + # support uva buffer. + self.req_states.prefill_token_ids.copy_to_gpu(), self.req_states.prefill_len.gpu, self.req_states.num_computed_tokens, ) @@ -255,7 +279,9 @@ class NPUModelRunner(GPUModelRunner): num_computed_tokens_cpu=self.req_states. num_computed_tokens_cpu[idx_mapping_cpu], block_tables=block_tables, - slot_mappings=slot_mappings, + # torch_npu._reshape_and_cache operator requires slot_mappings to + # be torch.int32. 
+            slot_mappings=slot_mappings.to(torch.int32),
             kv_cache_config=self.kv_cache_config,
             decode_token_per_req=self.decode_token_per_req,
             attn_mask=attn_mask,
@@ -344,3 +370,8 @@
                     req_index]
             self.input_buffers.seq_lens_cpu[
                 i] = num_computed_tokens + num_scheduled_tokens[req_id]
+
+    def eplb_warmup(self):
+        # TODO(Ronald1995): defined as a no-op so the worker's call does not
+        # fail; implement it properly in the future.
+        pass
diff --git a/vllm_ascend/worker/v2/sample/__init__.py b/vllm_ascend/worker/v2/sample/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vllm_ascend/worker/v2/sample/gumbel.py b/vllm_ascend/worker/v2/sample/gumbel.py
new file mode 100644
index 00000000..838711a6
--- /dev/null
+++ b/vllm_ascend/worker/v2/sample/gumbel.py
@@ -0,0 +1,128 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/gumbel.py.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def _gumbel_sample_kernel(
+    local_argmax_ptr,
+    local_argmax_stride,
+    local_max_ptr,
+    local_max_stride,
+    logits_ptr,
+    logits_stride,
+    seeds_ptr,
+    pos_ptr,
+    temp_ptr,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+    APPLY_TEMPERATURE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    block_idx = tl.program_id(1)
+    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = block < vocab_size
+    logits = tl.load(
+        logits_ptr + req_idx * logits_stride + block,
+        mask=mask,
+        other=float("-inf"),
+    )
+    logits = logits.to(tl.float32)
+
+    temp = tl.load(temp_ptr + req_idx).to(tl.float32)
+    if temp != 0.0:
+        # Calculate the seed for gumbel noise.
+        seed = tl.load(seeds_ptr + req_idx)
+        # NOTE(Ronald1995): cast pos to tl.int32 because triton-ascend's
+        # compiler doesn't support a uint64 pos argument.
+        pos = tl.load(pos_ptr + req_idx).to(tl.int32)
+        gumbel_seed = tl.randint(seed, pos)
+
+        # Generate gumbel noise.
+        # NOTE(Ronald1995): r is tl.float64 in vllm; change it to tl.float32,
+        # or triton-ascend's compiler will raise an error.
+        r = tl.rand(gumbel_seed, block).to(tl.float32)
+        gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
+        gumbel_noise = gumbel_noise.to(tl.float32)
+
+        # Apply temperature.
+        if APPLY_TEMPERATURE:
+            # NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
+            # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
+            logits = logits / temp
+
+        # Apply gumbel noise.
+        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
+
+    idx = tl.argmax(logits, axis=0)
+    token_id = block_idx * BLOCK_SIZE + idx
+    value = tl.max(logits, axis=0)
+    tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx,
+             token_id)
+    tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
+
+
+def gumbel_sample(
+    logits: torch.Tensor,  # [num_reqs, vocab_size]
+    temperature: torch.Tensor,  # [num_reqs]
+    seed: torch.Tensor,  # [num_reqs]
+    pos: torch.Tensor,  # [num_reqs]
+    apply_temperature: bool,
+) -> torch.Tensor:
+    """Override the function because there are some bugs when
+    _gumbel_sample_kernel runs on NPU and we need to make some fixes;
+    see the NOTE(Ronald1995) comments for details.
+    """
+    num_reqs, vocab_size = logits.shape
+    BLOCK_SIZE = 1024
+    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+    local_argmax = torch.empty(
+        num_reqs,
+        num_blocks,
+        dtype=torch.int64,
+        device=logits.device,
+    )
+    local_max = torch.empty(
+        num_reqs,
+        num_blocks,
+        dtype=torch.float32,
+        device=logits.device,
+    )
+    # TODO(Ronald1995): Optimize the performance of the kernel in npu.
+    _gumbel_sample_kernel[(num_reqs, num_blocks)](
+        local_argmax,
+        local_argmax.stride(0),
+        local_max,
+        local_max.stride(0),
+        logits,
+        logits.stride(0),
+        seed,
+        pos,
+        temperature,
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+        APPLY_TEMPERATURE=apply_temperature,
+    )
+    # NOTE(woosuk): Use int64 for later indexing.
+    max_block_idx = local_max.argmax(dim=-1, keepdim=True)
+    sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
+    return sampled
diff --git a/vllm_ascend/worker/v2/sample/penalties.py b/vllm_ascend/worker/v2/sample/penalties.py
new file mode 100644
index 00000000..fe730a9d
--- /dev/null
+++ b/vllm_ascend/worker/v2/sample/penalties.py
@@ -0,0 +1,137 @@
+# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/penalties.py.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+import torch
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
+
+
+@triton.jit
+def _penalties_and_temperature_kernel(
+    logits_ptr,
+    logits_stride,
+    repetition_penalty_ptr,
+    frequency_penalty_ptr,
+    presence_penalty_ptr,
+    temperature_ptr,
+    idx_mapping_ptr,
+    prompt_bin_mask_ptr,
+    prompt_bin_mask_stride,
+    output_bin_counts_ptr,
+    output_bin_counts_stride,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
+    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
+    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
+    temperature = tl.load(temperature_ptr + batch_idx)
+    temperature = tl.where(temperature == 0.0, 1.0, temperature)
+
+    use_rep_penalty = rep_penalty != 1.0
+    use_freq_penalty = freq_penalty != 0.0
+    use_pres_penalty = pres_penalty != 0.0
+    # NOTE(Ronald1995): the original vllm expression is `use_rep_penalty or
+    # use_freq_penalty or use_pres_penalty`; change it to
+    # `(use_rep_penalty or use_freq_penalty) or use_pres_penalty`,
+    # because triton-ascend's compiler doesn't support chained boolean operators.
+    use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
+    use_temperature = temperature != 1.0
+    if not (use_penalty or use_temperature):
+        # Early return to avoid loading logits.
+        return
+
+    block_idx = tl.program_id(1)
+    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = block < vocab_size
+    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
+    logits = logits.to(tl.float32)
+
+    if use_penalty:
+        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+        output_bin_counts = tl.load(
+            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride +
+            block,
+            mask=mask,
+        )
+        output_bin_mask = output_bin_counts > 0
+
+        # Apply repetition penalties.
+        if use_rep_penalty:
+            packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(
+                0, BLOCK_SIZE // 32)
+            packed_mask = tl.load(
+                prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride +
+                packed_block,
+                mask=packed_block < tl.cdiv(vocab_size, 32),
+            )
+            prompt_bin_mask = (packed_mask[:, None] >>
+                               (tl.arange(0, 32)[None, :])) & 1
+            prompt_bin_mask = prompt_bin_mask.to(tl.int1)
+            prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
+
+            # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+            scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty,
+                             1.0)
+            # If logits are positive, divide by penalty, otherwise multiply by penalty.
+            logits *= tl.where(logits > 0, 1.0 / scale, scale)
+
+        # Apply frequency penalties.
+        logits -= freq_penalty * output_bin_counts
+        # Apply presence penalties.
+        logits -= pres_penalty * output_bin_mask
+
+    # Apply temperature.
+    logits = logits / temperature
+
+    # Store back to logits.
+    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
+
+
+def apply_penalties_and_temperature(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> None:
+    """Override the function because there are some bugs when
+    _penalties_and_temperature_kernel runs on NPU and we need to make some
+    fixes; see the NOTE(Ronald1995) comments for details.
+    """
+    num_reqs, vocab_size = logits.shape
+    # NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 to avoid UB
+    # overflow in triton-ascend.
+    BLOCK_SIZE = 4096
+    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
+    # TODO(Ronald1995): Optimize the performance of the kernel in npu.
+ _penalties_and_temperature_kernel[(num_reqs, num_blocks)]( + logits, + logits.stride(0), + sampling_metadata.repetition_penalty, + sampling_metadata.frequency_penalty, + sampling_metadata.presence_penalty, + sampling_metadata.temperature, + sampling_metadata.idx_mapping, + sampling_metadata.prompt_bin_mask, + sampling_metadata.prompt_bin_mask.stride(0), + sampling_metadata.output_bin_counts, + sampling_metadata.output_bin_counts.stride(0), + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/vllm_ascend/worker/v2/sample/sampler.py b/vllm_ascend/worker/v2/sample/sampler.py new file mode 100644 index 00000000..e54536c7 --- /dev/null +++ b/vllm_ascend/worker/v2/sample/sampler.py @@ -0,0 +1,58 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/sampler.py. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import torch +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu.sample.min_p import apply_min_p +from vllm.v1.worker.gpu.sample.sampler import Sampler + +from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample +from vllm_ascend.worker.v2.sample.penalties import \ + apply_penalties_and_temperature + + +class AscendSampler(Sampler): + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Override sample method because we need to override triton operators + called in the method. + """ + # Copy logits to a new FP32 tensor. + logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) + + # Apply penalties and temperature in place. + apply_penalties_and_temperature(logits, sampling_metadata) + # Apply min_p in place. + if sampling_metadata.min_p is not None: + apply_min_p(logits, sampling_metadata.min_p) + # Apply top_k and/or top_p. This might return a new tensor. + logits = apply_top_k_top_p(logits, sampling_metadata.top_k, + sampling_metadata.top_p) + + sampled = gumbel_sample( + logits, + sampling_metadata.temperature, + sampling_metadata.seeds, + sampling_metadata.pos, + apply_temperature=False, + ) + return sampled, logits diff --git a/vllm_ascend/worker/v2/states.py b/vllm_ascend/worker/v2/states.py index 1364c869..d4eee5f9 100644 --- a/vllm_ascend/worker/v2/states.py +++ b/vllm_ascend/worker/v2/states.py @@ -1,8 +1,28 @@ +# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/states.py +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + from contextlib import contextmanager import torch +import vllm from vllm.v1.utils import CpuGpuBuffer -from vllm.v1.worker.gpu.states import RequestState, UvaBuffer +from vllm.v1.worker.gpu.states import RequestState class AscendRequestState(RequestState): @@ -18,16 +38,15 @@ class AscendRequestState(RequestState): device: torch.device, pin_memory: bool, ): - with uva_wrapper(): - super().__init__( - max_num_reqs, - max_model_len, - max_num_batched_tokens, - num_speculative_steps, - vocab_size, - device, - pin_memory, - ) + super().__init__( + max_num_reqs, + max_model_len, + max_num_batched_tokens, + num_speculative_steps, + vocab_size, + device, + pin_memory, + ) # because we will override these attribute, delete these attribute to # make sure it's collected by python gc immediately. del self.prefill_token_ids @@ -78,11 +97,9 @@ def uva_wrapper(): def __init__(self, *args, **kwargs): pass - # TODO(Ronald1995): rectify this when NPU support uva. - global UvaBuffer - ori_class = UvaBuffer try: - UvaBuffer = UvaBufferWrapper + # TODO(Ronald1995): rectify this when NPU support uva. + vllm.v1.worker.gpu.states.UvaBuffer = UvaBufferWrapper yield finally: - UvaBuffer = ori_class + pass diff --git a/vllm_ascend/worker/v2/utils.py b/vllm_ascend/worker/v2/utils.py index 20efb75f..0c28b2fb 100644 --- a/vllm_ascend/worker/v2/utils.py +++ b/vllm_ascend/worker/v2/utils.py @@ -5,29 +5,16 @@ import torch @contextmanager def torch_cuda_wrapper(): - ori_event = torch.cuda.Event - ori_stream = torch.cuda.Stream - ori_default_stream = torch.cuda.default_stream - ori_current_stream = torch.cuda.current_stream - ori_graph_pool_handle = torch.cuda.graph_pool_handle - ori_cuda_graph_cls = torch.cuda.CUDAGraph - ori_cuda_graph_func = torch.cuda.graph try: torch.cuda.Event = torch.npu.Event torch.cuda.Stream = torch.npu.Stream + torch.cuda.stream = torch.npu.stream torch.cuda.default_stream = torch.npu.default_stream torch.cuda.current_stream = torch.npu.current_stream torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle - torch.cuda.CUDAGraph = torch.npu.NpuGraph + torch.cuda.CUDAGraph = torch.npu.NPUGraph torch.cuda.graph = torch.npu.graph + torch.cuda.synchronize = torch.npu.synchronize yield finally: - # revert back torch cuda properties, so it will still raise error - # to call cuda ops in npu environment. - torch.cuda.Event = ori_event - torch.cuda.Stream = ori_stream - torch.cuda.default_stream = ori_default_stream - torch.cuda.current_stream = ori_current_stream - torch.cuda.graph_pool_handle = ori_graph_pool_handle - torch.cuda.CUDAGraph = ori_cuda_graph_cls - torch.cuda.graph = ori_cuda_graph_func + pass diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py index 324cbed0..f061ba46 100644 --- a/vllm_ascend/worker/worker.py +++ b/vllm_ascend/worker/worker.py @@ -233,8 +233,8 @@ class NPUWorker(WorkerBase): self.device = self._init_device() # Init ModelRunner here, so that we have access to self.device. if self.use_v2_model_runner: - logger.error( - "npu model runner v2 is in developing, it can't work well for now." 
+            logger.warning(
+                "npu model runner v2 is under development; some features do not work yet."
             )
             from vllm_ascend.worker.v2.model_runner import \
                 NPUModelRunner as NPUModelRunnerV2