[Feature] support eager mode in model runner v2 (#5210)
### What this PR does / why we need it?
#5051 only implement a basic framework for model runner v2, but there
are still some bugs for e2e functionality, this PR aim to enable basic
functionality.
model runner v2 plans:
https://github.com/vllm-project/vllm-ascend/issues/5208
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
6
vllm_ascend/worker/v2/README.md
Normal file
6
vllm_ascend/worker/v2/README.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# [Experimental] Model Runner V2
|
||||
|
||||
This directory contains the new model runner which is under active development.
|
||||
|
||||
please see [Model Runner V2](https://github.com/vllm-project/vllm-ascend/issues/5208)
|
||||
to get specific plans.
|
||||
@@ -1,5 +1,21 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/aclgraph_utils.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -1,5 +1,22 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/attn_utils.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -1,3 +1,22 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/input_batch.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
|
||||
@@ -1,5 +1,21 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/model_runner.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -19,7 +35,8 @@ from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
|
||||
build_attn_state,
|
||||
make_attention_mask)
|
||||
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
|
||||
from vllm_ascend.worker.v2.states import AscendRequestState
|
||||
from vllm_ascend.worker.v2.sample.sampler import AscendSampler
|
||||
from vllm_ascend.worker.v2.states import AscendRequestState, uva_wrapper
|
||||
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -29,7 +46,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
"""Model runner for Ascend NPUs."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
with torch_cuda_wrapper():
|
||||
with (torch_cuda_wrapper(), uva_wrapper()):
|
||||
super().__init__(vllm_config, device)
|
||||
|
||||
# because we will override these attribute, delete these attribute to
|
||||
@@ -37,6 +54,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
del self.cudagraph_manager
|
||||
del self.req_states
|
||||
del self.input_buffers
|
||||
del self.sampler
|
||||
|
||||
# NPU specific initializations can be added below.
|
||||
self.cudagraph_manager: AclGraphManager = AclGraphManager(
|
||||
@@ -65,6 +83,10 @@ class NPUModelRunner(GPUModelRunner):
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
# we need to adjust triton operators in sampler,
|
||||
# so reinitialize sampler here.
|
||||
self.sampler: AscendSampler = AscendSampler(
|
||||
logprobs_mode=self.model_config.logprobs_mode, )
|
||||
|
||||
# actual seq lengths for query (used in attention backends).
|
||||
self.actual_seq_lengths_q: list[int] = []
|
||||
@@ -206,7 +228,9 @@ class NPUModelRunner(GPUModelRunner):
|
||||
self.req_states.next_prefill_tokens,
|
||||
idx_mapping_npu,
|
||||
query_start_loc_gpu,
|
||||
self.req_states.prefill_token_ids.gpu,
|
||||
# use prefill_token_ids.copy_to_gpu() because npu doesn't
|
||||
# support uva buffer.
|
||||
self.req_states.prefill_token_ids.copy_to_gpu(),
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.num_computed_tokens,
|
||||
)
|
||||
@@ -255,7 +279,9 @@ class NPUModelRunner(GPUModelRunner):
|
||||
num_computed_tokens_cpu=self.req_states.
|
||||
num_computed_tokens_cpu[idx_mapping_cpu],
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
# torch_npu._reshape_and_cache operator requires slot_mappings to
|
||||
# be torch.int32.
|
||||
slot_mappings=slot_mappings.to(torch.int32),
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
attn_mask=attn_mask,
|
||||
@@ -344,3 +370,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
req_index]
|
||||
self.input_buffers.seq_lens_cpu[
|
||||
i] = num_computed_tokens + num_scheduled_tokens[req_id]
|
||||
|
||||
def eplb_warmup(self):
|
||||
# TODO(Ronald1995): just define the method in case calling error in
|
||||
# worker, implement it in the future.
|
||||
pass
|
||||
|
||||
0
vllm_ascend/worker/v2/sample/__init__.py
Normal file
0
vllm_ascend/worker/v2/sample/__init__.py
Normal file
128
vllm_ascend/worker/v2/sample/gumbel.py
Normal file
128
vllm_ascend/worker/v2/sample/gumbel.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/gumbel.py.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gumbel_sample_kernel(
|
||||
local_argmax_ptr,
|
||||
local_argmax_stride,
|
||||
local_max_ptr,
|
||||
local_max_stride,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
APPLY_TEMPERATURE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
|
||||
if temp != 0.0:
|
||||
# Calculate the seed for gumbel noise.
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
# NOTE(Ronald1995): change pos's dtype to tl.int32, because triton-ascend's
|
||||
# compiler doesn't support unint64 of pos arg.
|
||||
pos = tl.load(pos_ptr + req_idx).to(tl.int32)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
# Generate gumbel noise.
|
||||
# NOTE(Ronald1995): r is tl.float64 in vllm, change it to tl.float32,
|
||||
# or triton-ascend's compiler will raise error.
|
||||
r = tl.rand(gumbel_seed, block).to(tl.float32)
|
||||
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
||||
gumbel_noise = gumbel_noise.to(tl.float32)
|
||||
|
||||
# Apply temperature.
|
||||
if APPLY_TEMPERATURE:
|
||||
# NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
|
||||
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
|
||||
logits = logits / temp
|
||||
|
||||
# Apply gumbel noise.
|
||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||
|
||||
idx = tl.argmax(logits, axis=0)
|
||||
token_id = block_idx * BLOCK_SIZE + idx
|
||||
value = tl.max(logits, axis=0)
|
||||
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx,
|
||||
token_id)
|
||||
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
|
||||
|
||||
|
||||
def gumbel_sample(
|
||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||
temperature: torch.Tensor, # [num_reqs]
|
||||
seed: torch.Tensor, # [num_reqs]
|
||||
pos: torch.Tensor, # [num_reqs]
|
||||
apply_temperature: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Override the function because there are some bugs
|
||||
when _gumbel_sample_kernel runs on npu, we need to make some fixes.
|
||||
you could read NOTE(Ronald1995) comments to understand.
|
||||
"""
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
local_argmax = torch.empty(
|
||||
num_reqs,
|
||||
num_blocks,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
local_max = torch.empty(
|
||||
num_reqs,
|
||||
num_blocks,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
# TODO(Ronald1995): Optimize the performance of the kernel in npu.
|
||||
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
||||
local_argmax,
|
||||
local_argmax.stride(0),
|
||||
local_max,
|
||||
local_max.stride(0),
|
||||
logits,
|
||||
logits.stride(0),
|
||||
seed,
|
||||
pos,
|
||||
temperature,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
APPLY_TEMPERATURE=apply_temperature,
|
||||
)
|
||||
# NOTE(woosuk): Use int64 for later indexing.
|
||||
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
||||
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
|
||||
return sampled
|
||||
137
vllm_ascend/worker/v2/sample/penalties.py
Normal file
137
vllm_ascend/worker/v2/sample/penalties.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/penalties.py.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _penalties_and_temperature_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
repetition_penalty_ptr,
|
||||
frequency_penalty_ptr,
|
||||
presence_penalty_ptr,
|
||||
temperature_ptr,
|
||||
idx_mapping_ptr,
|
||||
prompt_bin_mask_ptr,
|
||||
prompt_bin_mask_stride,
|
||||
output_bin_counts_ptr,
|
||||
output_bin_counts_stride,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
|
||||
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
|
||||
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
|
||||
temperature = tl.load(temperature_ptr + batch_idx)
|
||||
temperature = tl.where(temperature == 0.0, 1.0, temperature)
|
||||
|
||||
use_rep_penalty = rep_penalty != 1.0
|
||||
use_freq_penalty = freq_penalty != 0.0
|
||||
use_pres_penalty = pres_penalty != 0.0
|
||||
# NOTE(Ronald1995): vllm original grammar `use_rep_penalty or
|
||||
# use_freq_penalty or use_pres_penalty`,
|
||||
# change it to `(use_rep_penalty or use_freq_penalty) or use_pres_penalty`,
|
||||
# because triton-ascend's compiler doesn't support chained boolean operator.
|
||||
use_penalty = (use_rep_penalty or use_freq_penalty) or use_pres_penalty
|
||||
use_temperature = temperature != 1.0
|
||||
if not (use_penalty or use_temperature):
|
||||
# Early return to avoid loading logits.
|
||||
return
|
||||
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
if use_penalty:
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
output_bin_counts = tl.load(
|
||||
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride +
|
||||
block,
|
||||
mask=mask,
|
||||
)
|
||||
output_bin_mask = output_bin_counts > 0
|
||||
|
||||
# Apply repetition penalties.
|
||||
if use_rep_penalty:
|
||||
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(
|
||||
0, BLOCK_SIZE // 32)
|
||||
packed_mask = tl.load(
|
||||
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride +
|
||||
packed_block,
|
||||
mask=packed_block < tl.cdiv(vocab_size, 32),
|
||||
)
|
||||
prompt_bin_mask = (packed_mask[:, None] >>
|
||||
(tl.arange(0, 32)[None, :])) & 1
|
||||
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
|
||||
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
|
||||
|
||||
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
||||
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty,
|
||||
1.0)
|
||||
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
||||
logits *= tl.where(logits > 0, 1.0 / scale, scale)
|
||||
|
||||
# Apply frequency penalties.
|
||||
logits -= freq_penalty * output_bin_counts
|
||||
# Apply presence penalties.
|
||||
logits -= pres_penalty * output_bin_mask
|
||||
|
||||
# Apply temperature.
|
||||
logits = logits / temperature
|
||||
|
||||
# Store back to logits.
|
||||
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
|
||||
|
||||
|
||||
def apply_penalties_and_temperature(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> None:
|
||||
"""Override the function because there are some bugs
|
||||
when _penalties_and_temperature_kernel runs on npu, we need to make some fixes.
|
||||
you could read NOTE(Ronald1995) comments to understand.
|
||||
"""
|
||||
num_reqs, vocab_size = logits.shape
|
||||
# NOTE(Ronald1995): change BLOCK_SIZE from 8192 to 4096 in case UB overflow
|
||||
# in triton-ascend.
|
||||
BLOCK_SIZE = 4096
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
# TODO(Ronald1995): Optimize the performance of the kernel in npu.
|
||||
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
sampling_metadata.repetition_penalty,
|
||||
sampling_metadata.frequency_penalty,
|
||||
sampling_metadata.presence_penalty,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.idx_mapping,
|
||||
sampling_metadata.prompt_bin_mask,
|
||||
sampling_metadata.prompt_bin_mask.stride(0),
|
||||
sampling_metadata.output_bin_counts,
|
||||
sampling_metadata.output_bin_counts.stride(0),
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
58
vllm_ascend/worker/v2/sample/sampler.py
Normal file
58
vllm_ascend/worker/v2/sample/sampler.py
Normal file
@@ -0,0 +1,58 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/sampler.py.
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import torch
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
|
||||
from vllm_ascend.worker.v2.sample.gumbel import gumbel_sample
|
||||
from vllm_ascend.worker.v2.sample.penalties import \
|
||||
apply_penalties_and_temperature
|
||||
|
||||
|
||||
class AscendSampler(Sampler):
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Override sample method because we need to override triton operators
|
||||
called in the method.
|
||||
"""
|
||||
# Copy logits to a new FP32 tensor.
|
||||
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
|
||||
|
||||
# Apply penalties and temperature in place.
|
||||
apply_penalties_and_temperature(logits, sampling_metadata)
|
||||
# Apply min_p in place.
|
||||
if sampling_metadata.min_p is not None:
|
||||
apply_min_p(logits, sampling_metadata.min_p)
|
||||
# Apply top_k and/or top_p. This might return a new tensor.
|
||||
logits = apply_top_k_top_p(logits, sampling_metadata.top_k,
|
||||
sampling_metadata.top_p)
|
||||
|
||||
sampled = gumbel_sample(
|
||||
logits,
|
||||
sampling_metadata.temperature,
|
||||
sampling_metadata.seeds,
|
||||
sampling_metadata.pos,
|
||||
apply_temperature=False,
|
||||
)
|
||||
return sampled, logits
|
||||
@@ -1,8 +1,28 @@
|
||||
# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/states.py
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu.states import RequestState, UvaBuffer
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
|
||||
|
||||
class AscendRequestState(RequestState):
|
||||
@@ -18,16 +38,15 @@ class AscendRequestState(RequestState):
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
):
|
||||
with uva_wrapper():
|
||||
super().__init__(
|
||||
max_num_reqs,
|
||||
max_model_len,
|
||||
max_num_batched_tokens,
|
||||
num_speculative_steps,
|
||||
vocab_size,
|
||||
device,
|
||||
pin_memory,
|
||||
)
|
||||
super().__init__(
|
||||
max_num_reqs,
|
||||
max_model_len,
|
||||
max_num_batched_tokens,
|
||||
num_speculative_steps,
|
||||
vocab_size,
|
||||
device,
|
||||
pin_memory,
|
||||
)
|
||||
# because we will override these attribute, delete these attribute to
|
||||
# make sure it's collected by python gc immediately.
|
||||
del self.prefill_token_ids
|
||||
@@ -78,11 +97,9 @@ def uva_wrapper():
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
# TODO(Ronald1995): rectify this when NPU support uva.
|
||||
global UvaBuffer
|
||||
ori_class = UvaBuffer
|
||||
try:
|
||||
UvaBuffer = UvaBufferWrapper
|
||||
# TODO(Ronald1995): rectify this when NPU support uva.
|
||||
vllm.v1.worker.gpu.states.UvaBuffer = UvaBufferWrapper
|
||||
yield
|
||||
finally:
|
||||
UvaBuffer = ori_class
|
||||
pass
|
||||
|
||||
@@ -5,29 +5,16 @@ import torch
|
||||
|
||||
@contextmanager
|
||||
def torch_cuda_wrapper():
|
||||
ori_event = torch.cuda.Event
|
||||
ori_stream = torch.cuda.Stream
|
||||
ori_default_stream = torch.cuda.default_stream
|
||||
ori_current_stream = torch.cuda.current_stream
|
||||
ori_graph_pool_handle = torch.cuda.graph_pool_handle
|
||||
ori_cuda_graph_cls = torch.cuda.CUDAGraph
|
||||
ori_cuda_graph_func = torch.cuda.graph
|
||||
try:
|
||||
torch.cuda.Event = torch.npu.Event
|
||||
torch.cuda.Stream = torch.npu.Stream
|
||||
torch.cuda.stream = torch.npu.stream
|
||||
torch.cuda.default_stream = torch.npu.default_stream
|
||||
torch.cuda.current_stream = torch.npu.current_stream
|
||||
torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle
|
||||
torch.cuda.CUDAGraph = torch.npu.NpuGraph
|
||||
torch.cuda.CUDAGraph = torch.npu.NPUGraph
|
||||
torch.cuda.graph = torch.npu.graph
|
||||
torch.cuda.synchronize = torch.npu.synchronize
|
||||
yield
|
||||
finally:
|
||||
# revert back torch cuda properties, so it will still raise error
|
||||
# to call cuda ops in npu environment.
|
||||
torch.cuda.Event = ori_event
|
||||
torch.cuda.Stream = ori_stream
|
||||
torch.cuda.default_stream = ori_default_stream
|
||||
torch.cuda.current_stream = ori_current_stream
|
||||
torch.cuda.graph_pool_handle = ori_graph_pool_handle
|
||||
torch.cuda.CUDAGraph = ori_cuda_graph_cls
|
||||
torch.cuda.graph = ori_cuda_graph_func
|
||||
pass
|
||||
|
||||
@@ -233,8 +233,8 @@ class NPUWorker(WorkerBase):
|
||||
self.device = self._init_device()
|
||||
# Init ModelRunner here, so that we have access to self.device.
|
||||
if self.use_v2_model_runner:
|
||||
logger.error(
|
||||
"npu model runner v2 is in developing, it can't work well for now."
|
||||
logger.warning(
|
||||
"npu model runner v2 is in developing, some features doesn't work for now."
|
||||
)
|
||||
from vllm_ascend.worker.v2.model_runner import \
|
||||
NPUModelRunner as NPUModelRunnerV2
|
||||
|
||||
Reference in New Issue
Block a user