[Feature] support eager mode in model runner v2 (#5210)
### What this PR does / why we need it?
#5051 only implemented a basic framework for model runner v2; several bugs still blocked end-to-end functionality. This PR enables basic e2e functionality.
Model runner v2 plan: https://github.com/vllm-project/vllm-ascend/issues/5208
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
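A minimal usage sketch, mirroring the new e2e test added below: `VLLM_USE_V2_MODEL_RUNNER=1` selects model runner v2 and `enforce_eager=True` keeps eager mode. Using the public `vllm.LLM` API here instead of the test-only `VllmRunner` helper is an assumption for illustration.

```python
# Hypothetical quick-start, assuming the flag is read at engine initialization
# (as in the worker change in this PR).
import os

os.environ["VLLM_USE_V2_MODEL_RUNNER"] = "1"  # select model runner v2 before engine init

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B", max_model_len=1024, enforce_eager=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)
```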
.github/workflows/_e2e_test.yaml (vendored, 2 lines changed)
@@ -125,6 +125,8 @@ jobs:
           pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
           pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+          pytest -sv --durations=0 tests/e2e/singlecard/model_runner_v2/test_basic.py

   e2e-2-cards:
     name: multicard-2
     runs-on: linux-aarch64-a3-2
tests/e2e/singlecard/model_runner_v2/__init__.py (new file, 0 lines)

tests/e2e/singlecard/model_runner_v2/test_basic.py (new file, 51 lines)
@@ -0,0 +1,51 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
from unittest.mock import patch

import pytest
from vllm import SamplingParams

from tests.e2e.conftest import VllmRunner

MODELS = ["Qwen/Qwen3-0.6B"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("enforce_eager", [True])
@patch.dict(os.environ, {"VLLM_USE_V2_MODEL_RUNNER": "1"})
def test_qwen3_dense_eager_mode(
    model: str,
    max_tokens: int,
    enforce_eager: bool,
) -> None:
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
    with VllmRunner(
            model,
            max_model_len=1024,
            enforce_eager=enforce_eager,
    ) as runner:
        runner.model.generate(prompts, sampling_params)
@@ -21,6 +21,7 @@ from typing import ClassVar, List, Optional, Tuple, Type

 import torch
 import torch_npu
+import vllm.envs as envs_vllm
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.registry import (AttentionBackendEnum,
@@ -54,7 +55,10 @@ class AscendAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "CUSTOM"
+        # HACK(Ronald1995): vLLM's `initialize_kv_cache` in model runner v2 asserts
+        # on the attention backend name; return FLASH_ATTN to avoid the assertion error.
+        # Rectify this once vLLM drops the assertion.
+        return "CUSTOM" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

     @staticmethod
     def get_impl_cls() -> Type["AscendAttentionBackendImpl"]:
@@ -535,7 +539,10 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 attn_metadata: AscendMetadata,
                 output: torch.Tensor):
         forward_context: ForwardContext = get_forward_context()
-        if forward_context.capturing:
+        # Model runner v2 inherits ForwardContext, and its context has no
+        # `capturing` attribute; use getattr to avoid an AttributeError when
+        # model runner v2 is enabled.
+        if getattr(forward_context, "capturing", False):
             attn_output, num_tokens = self.full_graph_fia(
                 query, key, value, attn_metadata, output)
             output[:num_tokens] = attn_output[:num_tokens]
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
 import numpy as np
 import torch
 import torch_npu
+import vllm.envs as envs_vllm
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig, get_current_vllm_config
@@ -53,7 +54,10 @@ class AscendMLABackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "ASCEND_MLA"
+        # HACK(Ronald1995): vLLM's `initialize_kv_cache` in model runner v2 asserts
+        # on the attention backend name; return FLASH_ATTN to avoid the assertion error.
+        # Rectify this once vLLM drops the assertion.
+        return "ASCEND_MLA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

     @staticmethod
     def get_builder_cls():
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar

 import torch
 import torch_npu
+import vllm.envs as envs_vllm
 from torch import nn
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
 from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
@@ -44,7 +45,10 @@ class AscendSFABackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "ASCEND_SFA"
+        # HACK(Ronald1995): vLLM's `initialize_kv_cache` in model runner v2 asserts
+        # on the attention backend name; return FLASH_ATTN to avoid the assertion error.
+        # Rectify this once vLLM drops the assertion.
+        return "ASCEND_SFA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN"

     @staticmethod
     def get_builder_cls():
@@ -81,7 +81,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor,
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)

-    if not forward_context.sp_enabled:
+    if not getattr(forward_context, "sp_enabled", False):
         return tensor_model_parallel_all_reduce(x)

     dp_metadata = forward_context.dp_metadata
vllm_ascend/worker/v2/README.md (new file, 6 lines)
@@ -0,0 +1,6 @@
# [Experimental] Model Runner V2

This directory contains the new model runner, which is under active development.

Please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208)
for the detailed plans.
@@ -1,5 +1,21 @@
 # Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/aclgraph_utils.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
 from contextlib import contextmanager
 from typing import Any

@@ -1,5 +1,22 @@
 # Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/attn_utils.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#

 from collections.abc import Sequence
 from typing import Any

@@ -1,3 +1,22 @@
 # Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/input_batch.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#

 import numpy as np
 import torch
 from vllm.v1.worker.gpu.input_batch import InputBuffers

@@ -1,5 +1,21 @@
 # Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_runner.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#

 import numpy as np
 import torch
@@ -19,7 +35,8 @@ from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
                                               build_attn_state,
                                               make_attention_mask)
 from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
-from vllm_ascend.worker.v2.states import AscendRequestState
+from vllm_ascend.worker.v2.sample.sampler import AscendSampler
+from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
 from vllm_ascend.worker.v2.utils import torch_cuda_wrapper

 logger = init_logger(__name__)
@@ -29,7 +46,7 @@ class NPUModelRunner(GPUModelRunner):
     """Model runner for Ascend NPUs."""

     def __init__(self, vllm_config: VllmConfig, device: torch.device):
-        with torch_cuda_wrapper():
+        with (torch_cuda_wrapper(), uva_wrapper()):
             super().__init__(vllm_config, device)

         # because we will override these attribute, delete these attribute to
@@ -37,6 +54,7 @@ class NPUModelRunner(GPUModelRunner):
         del self.cudagraph_manager
         del self.req_states
         del self.input_buffers
+        del self.sampler

         # NPU specific initializations can be added below.
         self.cudagraph_manager: AclGraphManager = AclGraphManager(
@@ -65,6 +83,10 @@ class NPUModelRunner(GPUModelRunner):
             device=self.device,
             pin_memory=self.pin_memory,
         )
+        # The triton operators used by the sampler need NPU-specific adjustments,
+        # so reinitialize the sampler here.
+        self.sampler: AscendSampler = AscendSampler(
+            logprobs_mode=self.model_config.logprobs_mode, )

         # actual seq lengths for query (used in attention backends).
         self.actual_seq_lengths_q: list[int] = []
@@ -206,7 +228,9 @@ class NPUModelRunner(GPUModelRunner):
             self.req_states.next_prefill_tokens,
             idx_mapping_npu,
             query_start_loc_gpu,
-            self.req_states.prefill_token_ids.gpu,
+            # Use prefill_token_ids.copy_to_gpu() because the NPU doesn't
+            # support UVA buffers.
+            self.req_states.prefill_token_ids.copy_to_gpu(),
             self.req_states.prefill_len.gpu,
             self.req_states.num_computed_tokens,
         )
@@ -255,7 +279,9 @@ class NPUModelRunner(GPUModelRunner):
             num_computed_tokens_cpu=self.req_states.
             num_computed_tokens_cpu[idx_mapping_cpu],
             block_tables=block_tables,
-            slot_mappings=slot_mappings,
+            # The torch_npu._reshape_and_cache operator requires slot_mappings
+            # to be torch.int32.
+            slot_mappings=slot_mappings.to(torch.int32),
             kv_cache_config=self.kv_cache_config,
             decode_token_per_req=self.decode_token_per_req,
             attn_mask=attn_mask,
@@ -344,3 +370,8 @@ class NPUModelRunner(GPUModelRunner):
                 req_index]
             self.input_buffers.seq_lens_cpu[
                 i] = num_computed_tokens + num_scheduled_tokens[req_id]
+
+    def eplb_warmup(self):
+        # TODO(Ronald1995): defined here only to avoid a call error in the
+        # worker; implement it in the future.
+        pass
vllm_ascend/worker/v2/sample/__init__.py (new file, 0 lines)

vllm_ascend/worker/v2/sample/gumbel.py (new file, 128 lines)
@@ -0,0 +1,128 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/gumbel.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
from vllm.triton_utils import tl, triton


@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    logits_ptr,
    logits_stride,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(
        logits_ptr + req_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    temp = tl.load(temp_ptr + req_idx).to(tl.float32)
    if temp != 0.0:
        # Calculate the seed for gumbel noise.
        seed = tl.load(seeds_ptr + req_idx)
        # NOTE(Ronald1995): change pos's dtype to tl.int32, because the
        # triton-ascend compiler doesn't support uint64 for the pos arg.
        pos = tl.load(pos_ptr + req_idx).to(tl.int32)
        gumbel_seed = tl.randint(seed, pos)

        # Generate gumbel noise.
        # NOTE(Ronald1995): r is tl.float64 in vllm; change it to tl.float32,
        # or the triton-ascend compiler will raise an error.
        r = tl.rand(gumbel_seed, block).to(tl.float32)
        gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
        gumbel_noise = gumbel_noise.to(tl.float32)

        # Apply temperature.
        if APPLY_TEMPERATURE:
            # NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
            # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
            logits = logits / temp

        # Apply gumbel noise.
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

    idx = tl.argmax(logits, axis=0)
    token_id = block_idx * BLOCK_SIZE + idx
    value = tl.max(logits, axis=0)
    tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx,
             token_id)
    tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)


def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    """Override the vLLM function because _gumbel_sample_kernel has some bugs
    when it runs on NPU; see the NOTE(Ronald1995) comments for the fixes.
    """
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    local_argmax = torch.empty(
        num_reqs,
        num_blocks,
        dtype=torch.int64,
        device=logits.device,
    )
    local_max = torch.empty(
        num_reqs,
        num_blocks,
        dtype=torch.float32,
        device=logits.device,
    )
    # TODO(Ronald1995): Optimize the performance of the kernel on NPU.
    _gumbel_sample_kernel[(num_reqs, num_blocks)](
        local_argmax,
        local_argmax.stride(0),
        local_max,
        local_max.stride(0),
        logits,
        logits.stride(0),
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        APPLY_TEMPERATURE=apply_temperature,
    )
    # NOTE(woosuk): Use int64 for later indexing.
    max_block_idx = local_max.argmax(dim=-1, keepdim=True)
    sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
    return sampled
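For orientation, the kernel above implements the Gumbel-max trick: adding Gumbel noise `-log(-log(u))` to the logits and taking the argmax is equivalent to drawing a sample from `softmax(logits)`, while greedy requests (temperature == 0) skip the noise. The sketch below is a hypothetical plain-PyTorch rendering of that idea, for illustration only; it ignores the deterministic per-request seeding (`tl.randint` over seed and pos) and the block-wise two-stage argmax that the Triton kernel performs (the sampler in this PR calls the kernel with `APPLY_TEMPERATURE=False`, since temperature is already applied in the penalties kernel).

```python
import torch


def gumbel_sample_reference(logits: torch.Tensor,
                            temperature: torch.Tensor) -> torch.Tensor:
    """Plain-PyTorch sketch of the Gumbel-max trick in _gumbel_sample_kernel.

    Greedy requests (temperature == 0) reduce to a plain argmax; for the rest,
    adding Gumbel noise -log(-log(u)) and taking the argmax draws a sample from
    softmax(logits).  Per-request seeding and block-wise reduction are omitted.
    """
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)
    greedy = temperature.unsqueeze(-1) == 0.0
    noisy = torch.where(greedy, logits, logits + gumbel)
    return noisy.argmax(dim=-1)
```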
vllm_ascend/worker/v2/sample/penalties.py (new file, 137 lines)
@@ -0,0 +1,137 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/penalties.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata


@triton.jit
def _penalties_and_temperature_kernel(
    logits_ptr,
    logits_stride,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    temperature_ptr,
    idx_mapping_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
    temperature = tl.load(temperature_ptr + batch_idx)
    temperature = tl.where(temperature == 0.0, 1.0, temperature)

    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    # NOTE(Ronald1995): vllm's original expression is `use_rep_penalty or
    # use_freq_penalty or use_pres_penalty`; change it to
    # `(use_rep_penalty or use_freq_penalty) or use_pres_penalty` because the
    # triton-ascend compiler doesn't support chained boolean operators.
    use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
    use_temperature = temperature != 1.0
    if not (use_penalty or use_temperature):
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)

    if use_penalty:
        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
        output_bin_counts = tl.load(
            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride +
            block,
            mask=mask,
        )
        output_bin_mask = output_bin_counts > 0

        # Apply repetition penalties.
        if use_rep_penalty:
            packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(
                0, BLOCK_SIZE // 32)
            packed_mask = tl.load(
                prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride +
                packed_block,
                mask=packed_block < tl.cdiv(vocab_size, 32),
            )
            prompt_bin_mask = (packed_mask[:, None] >>
                               (tl.arange(0, 32)[None, :])) & 1
            prompt_bin_mask = prompt_bin_mask.to(tl.int1)
            prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)

            # If the token appears in the prompt or output, apply the penalty;
            # otherwise use 1.0 as a no-op.
            scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty,
                             1.0)
            # If logits are positive, divide by the penalty; otherwise multiply.
            logits *= tl.where(logits > 0, 1.0 / scale, scale)

        # Apply frequency penalties.
        logits -= freq_penalty * output_bin_counts
        # Apply presence penalties.
        logits -= pres_penalty * output_bin_mask

    # Apply temperature.
    logits = logits / temperature

    # Store back to logits.
    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)


def apply_penalties_and_temperature(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> None:
    """Override the vLLM function because _penalties_and_temperature_kernel has
    some bugs when it runs on NPU; see the NOTE(Ronald1995) comments for the fixes.
    """
    num_reqs, vocab_size = logits.shape
    # NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 to avoid UB overflow
    # in triton-ascend.
    BLOCK_SIZE = 4096
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    # TODO(Ronald1995): Optimize the performance of the kernel on NPU.
    _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
        logits,
        logits.stride(0),
        sampling_metadata.repetition_penalty,
        sampling_metadata.frequency_penalty,
        sampling_metadata.presence_penalty,
        sampling_metadata.temperature,
        sampling_metadata.idx_mapping,
        sampling_metadata.prompt_bin_mask,
        sampling_metadata.prompt_bin_mask.stride(0),
        sampling_metadata.output_bin_counts,
        sampling_metadata.output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
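For reference, the arithmetic applied by the kernel above can be written in plain PyTorch roughly as follows. This is a hypothetical dense-tensor sketch for illustration only: it assumes the bit-packed prompt mask has already been unpacked into a boolean tensor and skips the per-request early-exit logic.

```python
import torch


def apply_penalties_reference(logits, rep_penalty, freq_penalty, pres_penalty,
                              temperature, prompt_mask, output_counts):
    """Dense-tensor sketch of _penalties_and_temperature_kernel.

    logits:        float [num_reqs, vocab_size]
    prompt_mask:   bool  [num_reqs, vocab_size], token appeared in the prompt
    output_counts: int   [num_reqs, vocab_size], token count in the output so far
    """
    output_mask = output_counts > 0
    # Repetition penalty: divide positive logits, multiply negative ones.
    scale = torch.where(prompt_mask | output_mask, rep_penalty.unsqueeze(-1),
                        torch.ones_like(logits))
    logits = logits * torch.where(logits > 0, 1.0 / scale, scale)
    # Frequency and presence penalties.
    logits = logits - freq_penalty.unsqueeze(-1) * output_counts
    logits = logits - pres_penalty.unsqueeze(-1) * output_mask.float()
    # Temperature; 0 is treated as 1 (greedy decoding is handled elsewhere).
    temp = torch.where(temperature == 0.0, torch.ones_like(temperature),
                       temperature)
    return logits / temp.unsqueeze(-1)
```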
vllm_ascend/worker/v2/sample/sampler.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/sampler.py.
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.sampler import Sampler

from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
from vllm_ascend.worker.v2.sample.penalties import \
    apply_penalties_and_temperature


class AscendSampler(Sampler):

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Override the sample method because the triton operators it calls
        need the NPU-specific overrides above.
        """
        # Copy logits to a new FP32 tensor.
        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)

        # Apply penalties and temperature in place.
        apply_penalties_and_temperature(logits, sampling_metadata)
        # Apply min_p in place.
        if sampling_metadata.min_p is not None:
            apply_min_p(logits, sampling_metadata.min_p)
        # Apply top_k and/or top_p. This might return a new tensor.
        logits = apply_top_k_top_p(logits, sampling_metadata.top_k,
                                   sampling_metadata.top_p)

        sampled = gumbel_sample(
            logits,
            sampling_metadata.temperature,
            sampling_metadata.seeds,
            sampling_metadata.pos,
            apply_temperature=False,
        )
        return sampled, logits
@@ -1,8 +1,28 @@
 # Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/states.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#

 from contextlib import contextmanager

 import torch
+import vllm
 from vllm.v1.utils import CpuGpuBuffer
-from vllm.v1.worker.gpu.states import RequestState, UvaBuffer
+from vllm.v1.worker.gpu.states import RequestState


 class AscendRequestState(RequestState):
@@ -18,16 +38,15 @@ class AscendRequestState(RequestState):
         device: torch.device,
         pin_memory: bool,
     ):
-        with uva_wrapper():
-            super().__init__(
-                max_num_reqs,
-                max_model_len,
-                max_num_batched_tokens,
-                num_speculative_steps,
-                vocab_size,
-                device,
-                pin_memory,
-            )
+        super().__init__(
+            max_num_reqs,
+            max_model_len,
+            max_num_batched_tokens,
+            num_speculative_steps,
+            vocab_size,
+            device,
+            pin_memory,
+        )
         # because we will override these attribute, delete these attribute to
         # make sure it's collected by python gc immediately.
         del self.prefill_token_ids
@@ -78,11 +97,9 @@ def uva_wrapper():
         def __init__(self, *args, **kwargs):
             pass

-    # TODO(Ronald1995): rectify this when NPU support uva.
-    global UvaBuffer
-    ori_class = UvaBuffer
     try:
-        UvaBuffer = UvaBufferWrapper
+        # TODO(Ronald1995): rectify this when NPU support uva.
+        vllm.v1.worker.gpu.states.UvaBuffer = UvaBufferWrapper
         yield
     finally:
-        UvaBuffer = ori_class
+        pass
@@ -5,29 +5,16 @@ import torch
 @contextmanager
 def torch_cuda_wrapper():
-    ori_event = torch.cuda.Event
-    ori_stream = torch.cuda.Stream
-    ori_default_stream = torch.cuda.default_stream
-    ori_current_stream = torch.cuda.current_stream
-    ori_graph_pool_handle = torch.cuda.graph_pool_handle
-    ori_cuda_graph_cls = torch.cuda.CUDAGraph
-    ori_cuda_graph_func = torch.cuda.graph
     try:
         torch.cuda.Event = torch.npu.Event
         torch.cuda.Stream = torch.npu.Stream
         torch.cuda.stream = torch.npu.stream
         torch.cuda.default_stream = torch.npu.default_stream
         torch.cuda.current_stream = torch.npu.current_stream
         torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle
-        torch.cuda.CUDAGraph = torch.npu.NpuGraph
+        torch.cuda.CUDAGraph = torch.npu.NPUGraph
         torch.cuda.graph = torch.npu.graph
+        torch.cuda.synchronize = torch.npu.synchronize
         yield
     finally:
-        # revert back torch cuda properties, so it will still raise error
-        # to call cuda ops in npu environment.
-        torch.cuda.Event = ori_event
-        torch.cuda.Stream = ori_stream
-        torch.cuda.default_stream = ori_default_stream
-        torch.cuda.current_stream = ori_current_stream
-        torch.cuda.graph_pool_handle = ori_graph_pool_handle
-        torch.cuda.CUDAGraph = ori_cuda_graph_cls
-        torch.cuda.graph = ori_cuda_graph_func
+        pass
@@ -233,8 +233,8 @@ class NPUWorker(WorkerBase):
         self.device = self._init_device()
         # Init ModelRunner here, so that we have access to self.device.
         if self.use_v2_model_runner:
-            logger.error(
-                "npu model runner v2 is in developing, it can't work well for now."
+            logger.warning(
+                "npu model runner v2 is still in development; some features don't work for now."
             )
             from vllm_ascend.worker.v2.model_runner import \
                 NPUModelRunner as NPUModelRunnerV2