implement model runner v2 basic framework (#5051)

### What this PR does / why we need it?
This PR aims to implement the basic framework of model runner v2 in vllm-ascend;
end-to-end (e2e) functionality is not guaranteed by this PR.
 
### Does this PR introduce _any_ user-facing change?
Use `envs.VLLM_USE_V2_MODEL_RUNNER` to choose whether model_runner_v2 is used.
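
For illustration, a minimal selection sketch. The module paths and the v1 fallback below are assumptions; the actual wiring lives in the worker code and is not part of the diff shown here.

```python
# Hypothetical sketch of how the flag could select the runner. Apart from
# NPUModelRunner and VLLM_USE_V2_MODEL_RUNNER, the names are assumptions.
import vllm_ascend.envs as envs  # assumption: the flag is exposed via vllm_ascend.envs


def select_model_runner_cls():
    if envs.VLLM_USE_V2_MODEL_RUNNER:
        # New v2 framework added by this PR.
        from vllm_ascend.worker.v2.model_runner import NPUModelRunner
        return NPUModelRunner
    # Fall back to the existing v1 runner (module path assumed).
    from vllm_ascend.worker.model_runner_v1 import NPUModelRunner as NPUModelRunnerV1
    return NPUModelRunnerV1
```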

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Authored by Ronald on 2025-12-18 15:51:54 +08:00, committed by GitHub
parent 1c8c23de58
commit b69b04d3a9
16 changed files with 843 additions and 98 deletions

vllm_ascend/worker/v2/aclgraph_utils.py
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.cudagraph_utils import \
prepare_inputs_to_capture as prepare_inputs_to_capture_gpu
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
class AclGraphManager(CudaGraphManager):
"""ACL Graph Manager for Ascend NPUs."""
def __init__(self, vllm_config: VllmConfig, device: torch.device):
with torch_cuda_wrapper():
super().__init__(vllm_config, device)
def capture_graph(
self,
num_tokens: int,
model: nn.Module,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
with (torch_cuda_wrapper(), prepare_capture_inputs_wrapper()):
super().capture_graph(
num_tokens,
model,
input_buffers,
block_tables,
attn_metadata_builders,
kv_cache_config,
)
@contextmanager
def prepare_capture_inputs_wrapper():
"""Context manager to override input preparation for NPU graph capture."""
# TODO(Ronald1995): make prepare_inputs_to_capture a static method
# of CudaGraphManager.
global prepare_inputs_to_capture_gpu
try:
ori_func = prepare_inputs_to_capture_gpu
prepare_inputs_to_capture_gpu = prepare_inputs_to_capture
yield
finally:
prepare_inputs_to_capture_gpu = ori_func
def prepare_inputs_to_capture(
num_reqs: int,
num_tokens: int,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
# TODO(Ronald1995): Implement NPU specific input preparation.
return {}
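
The wrapper above temporarily rebinds a module-level name and restores it in `finally`; note that rebinding an imported alias only affects the module holding the alias, so the override must ultimately be visible where `CudaGraphManager` looks the function up (hence the TODO above). A toy, self-contained sketch of the pattern (all names below are illustrative only):

```python
from contextlib import contextmanager


def real_impl():
    return "gpu path"


def npu_impl():
    return "npu path"


current_impl = real_impl


@contextmanager
def override_impl():
    global current_impl
    saved = current_impl
    try:
        current_impl = npu_impl
        yield
    finally:
        current_impl = saved  # always restore, even if the body raises


with override_impl():
    assert current_impl() == "npu path"
assert current_impl() == "gpu path"
```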

vllm_ascend/worker/v2/attn_utils.py
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Any
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.config.model import ModelDType
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, KVCacheConfig
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
AscendPrefillContextParallelMetadata)
_ATTENTION_MASK_BUILDER = None
def get_attn_mask_builder(device: torch.device):
"""Get attention mask builder which only have one instance."""
global _ATTENTION_MASK_BUILDER
if _ATTENTION_MASK_BUILDER is None:
_ATTENTION_MASK_BUILDER = AttentionMaskBuilder(device)
return _ATTENTION_MASK_BUILDER
def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int,
num_tokens: int,
query_start_loc_gpu: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
num_computed_tokens_cpu: torch.Tensor,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
decode_token_per_req: int,
actual_seq_lengths_q: list[int],
positions: torch.Tensor | None = None,
attn_mask: torch.Tensor
| None = None,
spec_attn_mask: torch.Tensor | None = None,
attn_state: Any | None = None,
is_only_prefill: bool = False,
graph_pad_size: int = -1,
num_input_tokens: int = 0,
prefill_context_parallel_metadata: AscendPrefillContextParallelMetadata
| None = None,
) -> dict[str, Any]:
"""Build attention metadata for Ascend NPUs."""
# TODO(Ronald1995): optimize AscendCommonAttentionMetadata.
max_query_len = int(torch.diff(query_start_loc_cpu).max())
attn_metadata: dict[str, Any] = {}
kv_cache_groups = kv_cache_config.kv_cache_groups
for i, kv_cache_spec in enumerate(kv_cache_groups):
block_table = block_tables[i]
slot_mapping = slot_mappings[i]
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens_cpu=seq_lens_cpu[:num_reqs],
seq_lens=seq_lens[:num_reqs],
num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
decode_token_per_req=decode_token_per_req,
block_table_tensor=block_table,
slot_mapping=slot_mapping,
actual_seq_lengths_q=actual_seq_lengths_q,
positions=positions,
attn_mask=attn_mask,
spec_attn_mask=spec_attn_mask,
attn_state=attn_state,
is_only_prefill=is_only_prefill,
graph_pad_size=graph_pad_size,
num_input_tokens=num_input_tokens,
prefill_context_parallel_metadata=prefill_context_parallel_metadata,
)
attn_metadata_builder = attn_metadata_builders[i]
metadata = attn_metadata_builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata, # type: ignore
)
for layer_name in kv_cache_spec.layer_names:
attn_metadata[layer_name] = metadata
return attn_metadata
def build_attn_state(
vllm_config: VllmConfig,
seq_lens_np: np.ndarray,
num_reqs,
num_scheduled_tokens,
num_valid_tokens,
):
"""Build attention state for npu's attention backend."""
if vllm_config.model_config.runner_type == "pooling":
if isinstance(
vllm_config.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
EncoderOnlyAttentionSpec,
):
attn_state = AscendAttentionState.PrefillNoCache
else:
attn_state = AscendAttentionState.PrefillCacheHit
elif np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume this is the decode stage, covering the case where prefill
# happened but only one token per request missed the cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
if (vllm_config.speculative_config
and vllm_config.speculative_config.method == 'mtp'):
# SpecDecoding now supports seq_len=1 and seq_len=2.
# In the prefill/decode disaggregation scenario, SpecDecoding
# needs to support seq_len=1.
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if (vllm_config.speculative_config
and vllm_config.speculative_config.method == 'mtp'):
attn_state = AscendAttentionState.SpecDecoding
else:
attn_state = AscendAttentionState.ChunkedPrefill
# splitfuse
elif vllm_config.scheduler_config.enable_chunked_prefill:
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
return attn_state
def make_attention_mask(
vllm_config: VllmConfig,
attn_state: AscendAttentionState,
dtype: ModelDType | torch.dtype,
device: torch.device,
) -> torch.Tensor:
"""make attention mask for npu's attention backend."""
attn_mask_builder = get_attn_mask_builder(device)
# pcp situation.
if attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
# Pooling situation.
if vllm_config.model_config.runner_type == "pooling":
return attn_mask_builder.get_attn_mask(2048, torch.bool)
# TODO(Ronald1995): consider pcp.
if vllm_config.model_config.use_mla:
# mla prefill
if attn_state != AscendAttentionState.DecodeOnly:
return attn_mask_builder.get_mla_mask(dtype)
return attn_mask_builder.get_splitfuse_attn_mask()
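
As a quick illustration of the checks in `build_attn_state` (values below are made up): a batch is treated as cache-miss prefill when every request's sequence length equals its scheduled token count, and as decode-only when every request is scheduled exactly one token.

```python
# Illustrative-only inputs for the two most common branches of build_attn_state.
import numpy as np

num_reqs = 3

# Prefill with no cache hit: seq_len == scheduled tokens for every request.
seq_lens_np = np.array([7, 5, 9], dtype=np.int32)
num_scheduled_tokens = np.array([7, 5, 9], dtype=np.int32)
assert np.array_equal(seq_lens_np[:num_reqs], num_scheduled_tokens)

# Decode-only: every request is scheduled exactly one token
# (seq_lens then exceed the scheduled counts because of cached context).
decode_seq_lens = np.array([12, 33, 8], dtype=np.int32)
decode_scheduled = np.array([1, 1, 1], dtype=np.int32)
assert np.all(decode_scheduled == 1)
```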

vllm_ascend/worker/v2/input_batch.py
@@ -0,0 +1,37 @@
import numpy as np
import torch
from vllm.v1.worker.gpu.input_batch import InputBuffers
class AscendInputBuffers(InputBuffers):
"""Input buffers for Ascend NPUs."""
def __init__(
self,
max_num_reqs: int,
max_num_tokens: int,
inputs_embeds_size: int,
vocab_size: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
super().__init__(
max_num_reqs,
max_num_tokens,
inputs_embeds_size,
vocab_size,
dtype,
device,
pin_memory,
)
# Create seq_lens_cpu and seq_lens_np.
# npu's attention backend still needs seq_lens on CPU side.
self.seq_lens_cpu: torch.Tensor = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device="cpu",
)
# seq_lens_np and seq_lens_cpu share the same memory; define
# seq_lens_np for easier calculation with numpy.
self.seq_lens_np: np.ndarray = self.seq_lens_cpu.numpy()
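
`seq_lens_np` is created with `Tensor.numpy()`, so it is a view over the same memory as `seq_lens_cpu`: updating either side is immediately visible through the other, which lets the runner do numpy arithmetic while the attention backends read the CPU tensor. A minimal standalone check:

```python
import torch

seq_lens_cpu = torch.zeros(4, dtype=torch.int32, device="cpu")
seq_lens_np = seq_lens_cpu.numpy()  # numpy view sharing the tensor's storage

seq_lens_np[0] = 128                   # write through numpy ...
assert seq_lens_cpu[0].item() == 128   # ... visible through the tensor
seq_lens_cpu[1] = 64                   # write through the tensor ...
assert seq_lens_np[1] == 64            # ... visible through numpy
```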

vllm_ascend/worker/v2/model_runner.py
@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.worker.gpu.input_batch import (InputBatch,
combine_sampled_and_draft_tokens,
prepare_pos_seq_lens,
prepare_prefill_inputs)
from vllm.v1.worker.gpu.model_runner import GPUModelRunner
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm_ascend.worker.v2.aclgraph_utils import AclGraphManager
from vllm_ascend.worker.v2.attn_utils import (build_attn_metadata,
build_attn_state,
make_attention_mask)
from vllm_ascend.worker.v2.input_batch import AscendInputBuffers
from vllm_ascend.worker.v2.states import AscendRequestState
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper
logger = init_logger(__name__)
class NPUModelRunner(GPUModelRunner):
"""Model runner for Ascend NPUs."""
def __init__(self, vllm_config: VllmConfig, device: torch.device):
with torch_cuda_wrapper():
super().__init__(vllm_config, device)
# Because we override these attributes below, delete them first so the
# old objects are collected by the Python GC immediately.
del self.cudagraph_manager
del self.req_states
del self.input_buffers
# NPU specific initializations can be added below.
self.cudagraph_manager: AclGraphManager = AclGraphManager(
vllm_config,
device,
)
# AscendRequestState has an extra `num_computed_tokens_cpu` attribute,
# so reinitialize req_states here.
self.req_states: AscendRequestState = AscendRequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
)
# AscendInputBuffers has an extra `seq_lens_cpu` attribute,
# so reinitialize input_buffers here.
self.input_buffers: AscendInputBuffers = AscendInputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
inputs_embeds_size=self.inputs_embeds_size,
vocab_size=self.vocab_size,
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
# actual seq lengths for query (used in attention backends).
self.actual_seq_lengths_q: list[int] = []
# decode token per request (used in attention backends).
self.decode_token_per_req = 1
# These attributes are for async scheduling with speculative decoding.
# Because the NPU attention backends still use seq_lens_cpu, we need to
# copy num_rejected_tokens back to the CPU to update the actual
# seq_lens_cpu. The GPU model runner does not need these attributes,
# since its attention backends no longer use seq_lens_cpu and the
# attribute is deprecated in gpu_model_runner_v2.
self.num_rejected_tokens_event = None
self.num_rejected_tokens_cpu = None
self.num_rejected_token_stream = None
if self.use_async_scheduling and self.do_spec_decode:
self.num_rejected_tokens_event = torch.npu.Event()
self.num_rejected_token_stream = torch.npu.Stream()
self.num_rejected_tokens_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory,
)
def prepare_inputs(
self,
scheduler_output: SchedulerOutput,
num_tokens_after_padding: int,
) -> InputBatch:
"""Override GPUModelRunner.prepare_inputs for Ascend NPUs.
NPU attention backends need seq_lens_cpu to work, so we prepare it
here.
"""
num_tokens = scheduler_output.total_num_scheduled_tokens
assert num_tokens > 0
num_reqs = len(scheduler_output.num_scheduled_tokens)
# Decode first, then prefill.
# batch_idx -> req_id
req_ids = sorted(
scheduler_output.num_scheduled_tokens.keys(),
key=lambda k: scheduler_output.num_scheduled_tokens[k],
)
self._update_seq_lens_cpu(scheduler_output, req_ids)
num_scheduled_tokens = np.array(
[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
dtype=np.int32)
num_valid_tokens = num_scheduled_tokens
if scheduler_output.scheduled_spec_decode_tokens:
num_valid_tokens = np.array(
[
num_tokens - len(
scheduler_output.scheduled_spec_decode_tokens.get(
i, []))
for num_tokens, i in zip(num_scheduled_tokens, req_ids)
],
dtype=np.int32,
)
attn_state = build_attn_state(
self.vllm_config,
self.input_buffers.seq_lens_np,
num_reqs,
num_scheduled_tokens,
num_valid_tokens,
)
attn_mask = make_attention_mask(
self.vllm_config,
attn_state,
self.dtype,
self.device,
)
idx_mapping_list = [
self.req_states.req_id_to_index[req_id] for req_id in req_ids
]
idx_mapping = self.input_buffers.idx_mapping
idx_mapping.np[:num_reqs] = idx_mapping_list
idx_mapping_np = idx_mapping.np[:num_reqs]
# Add `idx_mapping_cpu` here because vllm-ascend's
# self.req_states.num_computed_tokens_cpu is a CPU tensor, while it is
# a GPU tensor in vllm's gpu_model_runner_v2.
idx_mapping_cpu = idx_mapping.cpu[:num_reqs]
idx_mapping_npu = idx_mapping.copy_to_gpu(num_reqs)
# Get the number of draft tokens for each request.
if not scheduler_output.scheduled_spec_decode_tokens:
# No draft token scheduled (common case).
total_num_draft_tokens = 0
total_num_logits = num_reqs
cu_num_logits = torch.arange(num_reqs + 1,
device=self.device,
dtype=torch.int32)
else:
draft_tokens = scheduler_output.scheduled_spec_decode_tokens
num_draft_tokens = np.array(
[
len(draft_tokens[req_id]) if req_id in draft_tokens else 0
for req_id in req_ids
],
dtype=np.int32,
)
total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens
np.cumsum(
num_draft_tokens + 1,
out=self.input_buffers.cu_num_logits.np[1:num_reqs + 1],
)
cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(
num_reqs + 1)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping_npu)
# Get query_start_loc.
np.cumsum(
num_scheduled_tokens,
out=self.input_buffers.query_start_loc.np[1:num_reqs + 1],
)
# Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
self.input_buffers.query_start_loc.np[num_reqs + 1:] = num_tokens
self.input_buffers.query_start_loc.copy_to_gpu()
query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[:
num_reqs +
1]
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[:
num_reqs +
1]
query_start_loc_np = self.input_buffers.query_start_loc.np[:num_reqs +
1]
# Get prefill tokens.
prepare_prefill_inputs(
self.input_buffers.input_ids,
self.req_states.next_prefill_tokens,
idx_mapping_npu,
query_start_loc_gpu,
self.req_states.prefill_token_ids.gpu,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens,
)
# Prepare positions and seq_lens.
prepare_pos_seq_lens(
idx_mapping_npu,
query_start_loc_gpu,
self.req_states.num_computed_tokens,
self.input_buffers.positions,
self.input_buffers.seq_lens,
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices = combine_sampled_and_draft_tokens(
self.input_buffers.input_ids,
idx_mapping_npu,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
seq_lens,
self.req_states.prefill_len.gpu,
self.req_states.draft_tokens,
cu_num_logits,
total_num_logits,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc_gpu, self.input_buffers.positions[:num_tokens])
# Layer name -> attention metadata.
# TODO(Ronald1995): try to add a `build_attn_metadata` method to vllm's
# gpu_model_runner_v2 so that we don't have to override the whole
# `prepare_inputs` method like this.
attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc_gpu,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens=self.input_buffers.seq_lens,
seq_lens_cpu=self.input_buffers.seq_lens_cpu,
actual_seq_lengths_q=self.actual_seq_lengths_q,
num_computed_tokens_cpu=self.req_states.
num_computed_tokens_cpu[idx_mapping_cpu],
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
decode_token_per_req=self.decode_token_per_req,
attn_mask=attn_mask,
attn_state=attn_state,
)
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding]
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
idx_mapping=idx_mapping_npu,
idx_mapping_np=idx_mapping_np,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
num_draft_tokens=total_num_draft_tokens,
query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_np=self.input_buffers.seq_lens_np,
input_ids=input_ids,
positions=positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
)
def sample(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
"""Override GPUModelRunner.sample for Ascend NPUs.
When using async scheduling with speculative decoding, we need to copy
the NPU's num_rejected tensor to the CPU. These operations aren't needed
in gpu_model_runner_v2, because the GPU attention backends no longer use
seq_lens_cpu.
"""
sampler_output, num_sampled, num_rejected = super().sample(
hidden_states,
input_batch,
sampling_metadata,
grammar_output,
)
if self.num_rejected_tokens_event is not None:
# The NPU attention backends still use seq_lens_cpu, so when doing
# speculative decoding with async scheduling we need to copy
# num_rejected_tokens back to the CPU.
default_stream = torch.npu.current_stream()
assert self.num_rejected_token_stream is not None
assert self.num_rejected_tokens_cpu is not None
with torch.npu.stream(self.num_rejected_token_stream):
self.num_rejected_token_stream.wait_stream(default_stream)
self.num_rejected_tokens_cpu.copy_(
num_rejected,
non_blocking=True,
)
self.num_rejected_tokens_event.record()
return sampler_output, num_sampled, num_rejected
def _update_seq_lens_cpu(
self,
scheduler_output: SchedulerOutput,
req_ids: list[str],
):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
# Update num_computed_tokens_cpu.
# TODO(Ronald1995): update num_computed_tokens_cpu by considering
# num_rejected_tokens.
for req_id, num_computed_token in zip(
scheduler_output.scheduled_cached_reqs.req_ids,
scheduler_output.scheduled_cached_reqs.num_computed_tokens,
):
req_index = self.req_states.req_id_to_index[req_id]
self.req_states.num_computed_tokens_cpu[
req_index] = num_computed_token
# update seq_lens_cpu
for i, req_id in enumerate(req_ids):
req_index = self.req_states.req_id_to_index[req_id]
num_computed_tokens = self.req_states.num_computed_tokens_cpu[
req_index]
self.input_buffers.seq_lens_cpu[
i] = num_computed_tokens + num_scheduled_tokens[req_id]
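
To make the CPU-side bookkeeping concrete, here is a small worked example (illustrative values) of how `prepare_inputs` builds `query_start_loc` from the scheduled token counts and how `_update_seq_lens_cpu` derives `seq_lens_cpu`:

```python
import numpy as np

# Two decode requests plus one chunked-prefill request (illustrative values).
num_scheduled_tokens = np.array([1, 1, 4], dtype=np.int32)
num_computed_tokens = np.array([10, 3, 16], dtype=np.int32)  # tokens already in the KV cache

query_start_loc = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, out=query_start_loc[1:])
# query_start_loc == [0, 1, 2, 6]; 6 tokens are scheduled in total.

seq_lens_cpu = num_computed_tokens + num_scheduled_tokens
# seq_lens_cpu == [11, 4, 20]; this is what the Ascend attention backends
# read on the CPU side.
```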

vllm_ascend/worker/v2/states.py
@@ -0,0 +1,88 @@
from contextlib import contextmanager
import torch
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu.states import RequestState, UvaBuffer
class AscendRequestState(RequestState):
"""Request state for Ascend NPUs."""
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
pin_memory: bool,
):
with uva_wrapper():
super().__init__(
max_num_reqs,
max_model_len,
max_num_batched_tokens,
num_speculative_steps,
vocab_size,
device,
pin_memory,
)
# Because we override this attribute below, delete it first so the old
# object is collected by the Python GC immediately.
del self.prefill_token_ids
# vllm's gpu_model_runner_v2 deprecates the seq_lens_cpu attribute
# because most attention backends do not need it. However, the Ascend
# attention backend must use seq_lens_cpu, so we keep
# num_computed_tokens_cpu here; seq_lens_cpu is computed outside from
# num_computed_tokens_cpu plus the scheduled tokens per request.
self.num_computed_tokens_cpu: torch.Tensor = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
)
# NOTE(Ronald1995): Ascend NPUs do not support UVA yet,
# so we use CpuGpuBuffer to allocate prefill_token_ids buffer.
self.prefill_token_ids: CpuGpuBuffer = self._make_buffer( # type: ignore
(self.max_num_reqs, self.max_model_len),
dtype=torch.int32)
def add_request(
self,
req_id,
prompt_len,
prefill_token_ids,
num_computed_tokens,
sampling_params,
lora_request,
):
super().add_request(
req_id,
prompt_len,
prefill_token_ids,
num_computed_tokens,
sampling_params,
lora_request,
)
req_idx = self.req_id_to_index[req_id]
self.num_computed_tokens_cpu[req_idx] = num_computed_tokens
@contextmanager
def uva_wrapper():
"""Context manager to disable UVA for Ascend NPUs."""
class UvaBufferWrapper:
def __init__(self, *args, **kwargs):
pass
# TODO(Ronald1995): rectify this when NPUs support UVA.
global UvaBuffer
ori_class = UvaBuffer
try:
UvaBuffer = UvaBufferWrapper
yield
finally:
UvaBuffer = ori_class

vllm_ascend/worker/v2/utils.py
@@ -0,0 +1,33 @@
from contextlib import contextmanager
import torch
@contextmanager
def torch_cuda_wrapper():
ori_event = torch.cuda.Event
ori_stream = torch.cuda.Stream
ori_default_stream = torch.cuda.default_stream
ori_current_stream = torch.cuda.current_stream
ori_graph_pool_handle = torch.cuda.graph_pool_handle
ori_cuda_graph_cls = torch.cuda.CUDAGraph
ori_cuda_graph_func = torch.cuda.graph
try:
torch.cuda.Event = torch.npu.Event
torch.cuda.Stream = torch.npu.Stream
torch.cuda.default_stream = torch.npu.default_stream
torch.cuda.current_stream = torch.npu.current_stream
torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle
torch.cuda.CUDAGraph = torch.npu.NpuGraph
torch.cuda.graph = torch.npu.graph
yield
finally:
# Revert the torch.cuda attributes, so calling CUDA ops in an NPU
# environment still raises an error.
torch.cuda.Event = ori_event
torch.cuda.Stream = ori_stream
torch.cuda.default_stream = ori_default_stream
torch.cuda.current_stream = ori_current_stream
torch.cuda.graph_pool_handle = ori_graph_pool_handle
torch.cuda.CUDAGraph = ori_cuda_graph_cls
torch.cuda.graph = ori_cuda_graph_func
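
The wrapper's intended use is to let code written against `torch.cuda` (such as the upstream GPU model runner's initialization and graph capture) run on Ascend by aliasing those symbols to their `torch.npu` counterparts for the duration of the `with` block; outside the block the aliases are restored. A minimal sketch, assuming `torch_npu` is installed so that `torch.npu` exists:

```python
import torch
from vllm_ascend.worker.v2.utils import torch_cuda_wrapper

with torch_cuda_wrapper():
    stream = torch.cuda.Stream()   # actually constructs a torch.npu.Stream
    event = torch.cuda.Event()     # actually constructs a torch.npu.Event

# After the block the original torch.cuda attributes are back, so stray CUDA
# calls on an NPU-only host still fail loudly instead of silently running on NPU.
```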