[Feature] adapt to uva buffer and main2main (#6657)
### What this PR does / why we need it?
The vLLM model runner v2 uses a UVA buffer to prepare input data, but the NPU does not
support UVA yet, so this PR implements a UVA wrapper class that mimics the GPU UVA
backend. In addition, this PR makes some modifications to adapt to the newer main branch.
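
For context, a UVA (unified virtual addressing) buffer lets device kernels read host memory directly; without UVA support, a wrapper has to stage data in pinned host memory and copy it to the device explicitly. Below is a minimal sketch of that pattern only; the class and method names are illustrative and not the ones introduced by this PR, and it assumes `torch_npu` has registered the `npu` device:

```python
import torch


class HostStagedBuffer:
    """Illustrative stand-in for a UVA-style buffer on hardware without UVA:
    the host writes into a pinned CPU tensor, and the filled prefix is copied
    to a device-side mirror before kernels consume it."""

    def __init__(self, size: int, dtype: torch.dtype, device: str = "npu"):
        # Pinned host staging area written by the model runner on the CPU.
        self.cpu = torch.zeros(size, dtype=dtype).pin_memory()
        # Device mirror that the model actually reads.
        self.npu = torch.zeros(size, dtype=dtype, device=device)

    def to_device(self, num_elems: int) -> torch.Tensor:
        # Transfer only the prefix that was populated for this step.
        self.npu[:num_elems].copy_(self.cpu[:num_elems], non_blocking=True)
        return self.npu[:num_elems]
```

The actual wrapper in this PR presumably exposes whatever interface vLLM's v2 model runner expects from its GPU UVA buffers; the sketch only shows the staging-and-copy idea.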
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM main: 13397841ab
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Excerpt of the changed region (diff hunks from this commit):

```
@@ -551,7 +551,7 @@ class NPUPlatform(Platform):
    vllm_config: VllmConfig,
    dp_metadata,
    virtual_engine: int = 0,
    num_tokens: int | None = None,
    num_tokens: int = 0,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode=None,
    batch_descriptor=None,

@@ -601,10 +601,6 @@ class NPUPlatform(Platform):
if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER:
    return {}

num_actual_tokens = list(attn_metadata.values())[0].num_actual_tokens
if num_tokens is None:
    num_tokens = num_actual_tokens

moe_comm_type = select_moe_comm_method(
    num_tokens,
    vllm_config,

@@ -636,10 +632,13 @@ class NPUPlatform(Platform):

# TODO(Levi-JQ): another PR to normalize the enabling logic for sp/fc2
flashcomm_v2_enabled = flashcomm2_enable() and tp_world_size > 1 and num_tokens is not None
pad_size = 0
pad_size = None
padded_length = None
if sp_enabled or flashcomm_v2_enabled:
    pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size

if num_tokens is None and attn_metadata is not None:
    num_tokens = list(attn_metadata.values())[0].num_actual_tokens
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and dp_metadata is not None:
    max_tokens_across_dp = dp_metadata.max_tokens_across_dp_cpu.item()

@@ -648,8 +647,9 @@ class NPUPlatform(Platform):
        pad_size = padded_length - num_tokens
else:
    max_tokens_across_dp = num_tokens

mc2_mask = None
if num_tokens is not None:
    num_actual_tokens = num_tokens
    # NOTE: token num which need to pad to when mc2
    padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
    reserved_mc2_mask = get_mc2_mask()
```
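
A side note on the padding arithmetic visible in the excerpt: both the SP/FlashComm path and the MC2 path round the token count up to a multiple of the tensor-parallel world size, just expressed differently. A small standalone illustration with example values (plain Python, no vLLM imports; the variable names mirror the excerpt):

```python
import math

tp_world_size = 4

# SP / FlashComm v2 path: pad_size is the number of extra tokens needed so
# that num_tokens becomes divisible by tp_world_size (0 if already divisible).
num_tokens = 10
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size
assert pad_size == 2 and (num_tokens + pad_size) % tp_world_size == 0

# MC2 path: the padded target length itself, rounded up from the largest
# token count across the data-parallel group.
max_tokens_across_dp = 10
padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
assert padded_num_tokens == 12
```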