xc-llm-ascend/vllm_ascend/kv_offload/cpu_npu.py
meihanc 922e5c163b [main2main] upgrade vllm main 0202 (#6560)
### What this PR does / why we need it?
1. Fix `TypeError: FusedMoEParallelConfig.__init__() missing 1 required
positional argument: 'is_sequence_parallel'` due to
https://github.com/vllm-project/vllm/pull/32567
2. Fix `TypeError: '>' not supported between instances of 'MagicMock'
and 'int'` due to https://github.com/vllm-project/vllm/pull/33035
3. Fix `TypeError: Can't instantiate abstract class AscendMLAImpl with
abstract methods forward_mha, forward_mqa` and `AttributeError: 'bool'
object has no attribute 'process_weights_after_loading'` due to
https://github.com/vllm-project/vllm/pull/33284
4. Fix `'AscendSharedFusedMoE' object has no attribute
'_routed_input_transform'` due to
https://github.com/vllm-project/vllm/pull/32790
5. Fix `NPUModelRunner._dummy_run() got an unexpected keyword argument
'num_active_loras'` due to
https://github.com/vllm-project/vllm/pull/32005
6. Fix the problem caused by `'tuple' object has no attribute 'job_id'`
due to https://github.com/vllm-project/vllm/pull/27492
7. Fix the problem that `all_moe_layers` does not match
`vllm.moe_forward` and `vllm.moe_forward_shared`, due to
https://github.com/vllm-project/vllm/pull/33184
8. Add a patch to fix `got multiple values for keyword argument
'add_special_tokens'` due to
https://github.com/vllm-project/vllm/pull/32863
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: hfadzxy <starmoon_zhang@163.com>
2026-02-05 19:31:17 +08:00


import numpy as np
import torch
from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import AttentionBackend  # type: ignore
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
                                              TransferResult, TransferSpec)

from vllm_ascend.utils import vllm_version_is

logger = init_logger(__name__)


def expand_block_ids(
    block_ids: np.ndarray,
    block_size_factor: int,
    output: np.ndarray,
    skip_count: int = 0,
):
    """
    Convert a list of block IDs to a list of matching sub-block IDs,
    assuming each block is composed of block_size_factor actual blocks.
    Writes the result into the output array.
    The first skip_count sub-blocks of the first block are skipped.
    Note that skip_count must be less than block_size_factor.

    For example, if block_ids = [0, 1, 3] and block_size_factor = 4,
    then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15],
    since 0 maps to [0, 1, 2, 3],
    1 maps to [4, 5, 6, 7],
    and 3 maps to [12, 13, 14, 15].
    """
    assert skip_count < block_size_factor

    first_range = np.arange(skip_count, block_size_factor)
    full_range = np.arange(0, block_size_factor)

    output_idx = 0
    for i, block_id in enumerate(block_ids):
        base_block_id = block_id * block_size_factor
        indices = first_range if i == 0 else full_range
        output_end_idx = output_idx + len(indices)
        output[output_idx:output_end_idx] = base_block_id + indices
        output_idx = output_end_idx
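
# skip_count illustration (values chosen for illustration only): with
# block_ids = [1, 3], block_size_factor = 4 and skip_count = 2, the first
# block keeps only its trailing sub-blocks, yielding
# [6, 7, 12, 13, 14, 15] into an output array of size 6.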


class CpuNpuOffloadingHandler(OffloadingHandler):

    def __init__(
        self,
        gpu_block_size: int,
        cpu_block_size: int,
        num_cpu_blocks: int,
        gpu_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        assert cpu_block_size % gpu_block_size == 0
        self.block_size_factor = cpu_block_size // gpu_block_size

        # npu streams for npu->cpu and cpu->npu
        self.d2h_stream = torch.npu.Stream()
        self.h2d_stream = torch.npu.Stream()

        # job_id -> transfer npu event
        self.transfer_events: dict[int, torch.npu.Event] = {}
        # list of npu events available for reuse
        self.events_pool: list[torch.npu.Event] = []

        pin_memory = is_pin_memory_available()

        # allocate cpu tensors
        logger.info("Allocating %d CPU tensors...", len(gpu_caches))
        self.npu_tensors: list[torch.Tensor] = []
        self.cpu_tensors: list[tuple[torch.Tensor, torch.Tensor]] = []
        for layer_name, gpu_tensor in gpu_caches.items():
            self.npu_tensors.append(gpu_tensor)

            # dim 0 of the per-layer NPU cache indexes the key/value planes,
            # so gpu_tensor[0] is the key cache
            gpu_shape = gpu_tensor[0].shape
            num_blocks_idx = 0
            cpu_shape = list(gpu_shape)
            cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor

            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)

            # one pinned CPU tensor each for keys and values
            self.cpu_tensors.append((
                torch.zeros(
                    cpu_shape,
                    dtype=gpu_tensor[0].dtype,
                    device="cpu",
                    pin_memory=pin_memory,
                ),
                torch.zeros(
                    cpu_shape,
                    dtype=gpu_tensor[0].dtype,
                    device="cpu",
                    pin_memory=pin_memory,
                ),
            ))
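
    # Shape sketch (sizes assumed for illustration; the exact KV layout
    # depends on the attention backend): with gpu_block_size=128 and
    # cpu_block_size=512, block_size_factor is 4, and an NPU key cache of
    # shape (num_gpu_blocks, 128, num_heads, head_dim) is mirrored by a
    # pinned CPU tensor of shape (num_cpu_blocks * 4, 128, num_heads,
    # head_dim), with a second identical tensor for the values.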

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        logger.info("start transfer_async...")
        src_spec, dst_spec = spec

        # pick stream, tensors and block size factors by transfer direction
        if isinstance(src_spec, CPULoadStoreSpec):
            assert isinstance(dst_spec, GPULoadStoreSpec)
            stream = self.h2d_stream
            src_tensors = self.cpu_tensors
            dst_tensors = self.npu_tensors
            src_block_size_factor = self.block_size_factor
            dst_block_size_factor = 1
        else:
            assert isinstance(src_spec, GPULoadStoreSpec)
            assert isinstance(dst_spec, CPULoadStoreSpec)
            stream = self.d2h_stream
            src_tensors = self.npu_tensors
            dst_tensors = self.cpu_tensors
            src_block_size_factor = 1
            dst_block_size_factor = self.block_size_factor

        src_blocks = src_spec.block_ids
        dst_blocks = dst_spec.block_ids
        assert src_blocks.ndim == 1
        assert dst_blocks.ndim == 1

        # skip leading sub-blocks of the first dst block so both sides
        # cover the same number of sub-blocks
        dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
        src_sub_block_count = src_blocks.size * src_block_size_factor
        assert src_sub_block_count == (dst_blocks.size * dst_block_size_factor
                                       - dst_sub_blocks_to_skip)

        # build an (N, 2) array of (src_sub_block, dst_sub_block) pairs
        src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
        expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
        expand_block_ids(
            dst_blocks,
            dst_block_size_factor,
            src_to_dst[:, 1],
            skip_count=dst_sub_blocks_to_skip,
        )
        src_to_dst_tensor = torch.from_numpy(src_to_dst)

        # reuse a pooled event if available, and record it on the stream
        # once all per-layer copies have been enqueued
        event = self.events_pool.pop() if self.events_pool else torch.npu.Event()
        with torch.npu.stream(stream):
            for src_tensor, dst_tensor in zip(src_tensors, dst_tensors):
                src_key_cache, src_value_cache = src_tensor[0], src_tensor[1]
                dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1]
                torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache,
                                                src_to_dst_tensor)
                torch.ops._C_ascend.swap_blocks(src_value_cache,
                                                dst_value_cache,
                                                src_to_dst_tensor)
            event.record(stream)
        self.transfer_events[job_id] = event

        # success
        return True
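
    # Worked mapping example (block IDs assumed for illustration): storing
    # NPU -> CPU with block_size_factor=4, src_blocks=[3, 7] (NPU, factor 1)
    # and dst_blocks=[1] (CPU, factor 4) gives dst_sub_blocks_to_skip=2, so
    # the two NPU blocks land in the tail of CPU block 1:
    # src_to_dst = [[3, 6], [7, 7]].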

    def get_finished(self) -> list[TransferResult]:
        results: list[TransferResult] = []
        if vllm_version_is("v0.15.0"):
            # on vLLM v0.15.0, a transfer result is a (job_id, success) tuple
            for job_id, event in self.transfer_events.items():
                if event.query():
                    results.append((job_id, True))
                    self.events_pool.append(event)
            for job_id, _ in results:
                del self.transfer_events[job_id]
        else:
            # newer vLLM versions use the TransferResult class
            finished_job_ids = []
            for job_id, event in self.transfer_events.items():
                if event.query():
                    results.append(
                        TransferResult(
                            job_id=job_id,
                            success=True,
                            transfer_size=None,
                            transfer_time=None,
                            transfer_type=None,
                        ))
                    finished_job_ids.append(job_id)
                    self.events_pool.append(event)
            for job_id in finished_job_ids:
                del self.transfer_events[job_id]
        return results

    def wait(self, job_ids: set[int]) -> None:
        """
        Wait (block) until all specified transfer jobs are completed.
        """
        for job_id in job_ids:
            event = self.transfer_events.get(job_id)
            if event is not None:
                # This will block until the NPU event is complete
                event.synchronize()
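
# Minimal usage sketch (hypothetical wiring; in practice vLLM's offloading
# worker constructs the handler and drives these calls):
#
#   handler = CpuNpuOffloadingHandler(
#       gpu_block_size=128,
#       cpu_block_size=512,
#       num_cpu_blocks=1024,
#       gpu_caches=gpu_caches,        # layer name -> NPU KV cache tensor
#       attn_backends=attn_backends,  # layer name -> attention backend class
#   )
#   handler.transfer_async(job_id, (npu_spec, cpu_spec))  # start NPU -> CPU copy
#   done = handler.get_finished()  # poll for completed transfers
#   handler.wait({job_id})         # or block until specific jobs finish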