[Feature] adapt to uva buffer and main2main (#6657)

### What this PR does / why we need it?
vLLM model runner v2 uses a UVA buffer to prepare input data, but NPU
does not support UVA yet, so this PR implements a UVA wrapper class to
mimic the GPU's UVA backend. In addition, this PR makes some
modifications to adapt to the newer main branch.

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM main:
13397841ab

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-02-12 10:36:31 +08:00
committed by GitHub
parent 56269eae0e
commit f1ffb5fb19
14 changed files with 407 additions and 179 deletions

View File

@@ -551,7 +551,7 @@ class NPUPlatform(Platform):
vllm_config: VllmConfig,
dp_metadata,
virtual_engine: int = 0,
num_tokens: int | None = None,
num_tokens: int = 0,
num_tokens_across_dp: torch.Tensor | None = None,
cudagraph_runtime_mode=None,
batch_descriptor=None,
@@ -601,10 +601,6 @@ class NPUPlatform(Platform):
if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
return {}
num_actual_tokens = list(attn_metadata.values())[0].num_actual_tokens
if num_tokens is None:
num_tokens = num_actual_tokens
moe_comm_type = select_moe_comm_method(
num_tokens,
vllm_config,
@@ -636,10 +632,13 @@ class NPUPlatform(Platform):
# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
pad_size = 0
pad_size = None
padded_length = None
if sp_enabled or flashcomm_v2_enabled:
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
if num_tokens is None and attn_metadata is not None:
num_tokens = list(attn_metadata.values())[0].num_actual_tokens
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and dp_metadata is not None:
max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()
@@ -648,8 +647,9 @@ class NPUPlatform(Platform):
pad_size = padded_length - num_tokens
else:
max_tokens_across_dp = num_tokens
mc2_mask = None
if num_tokens is not None:
num_actual_tokens = num_tokens
# NOTE: token count to pad up to (a multiple of tp_world_size) when using mc2
padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
reserved_mc2_mask = get_mc2_mask()