refact runner model v1 (#2461)
refact model runner v1
### What this PR does / why we need it?
1. Separate the execute model logic from the prepare input logic
2. Disassemble the torchchair in model runner v1
- vLLM version: v0.10.0
- vLLM main:
68fcd3fa73
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -17,21 +17,29 @@
|
|||||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import types
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
import vllm.envs as envs_vllm
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||||
check_torchair_cache_exist,
|
check_torchair_cache_exist,
|
||||||
register_torchair_model,
|
register_torchair_model,
|
||||||
write_kv_cache_bytes_to_file)
|
write_kv_cache_bytes_to_file)
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||||
maybe_converting_weight_acl_format)
|
is_310p, maybe_converting_weight_acl_format)
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
|
|
||||||
@@ -39,6 +47,24 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||||
super().__init__(vllm_config, device)
|
super().__init__(vllm_config, device)
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
self.new_kv_cache_bytes = -1
|
||||||
|
self.torchair_compiled_model = None # type: ignore
|
||||||
|
self.torchair_compiled_models = {} # type: ignore
|
||||||
|
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||||
|
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
||||||
|
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||||
|
self.init_torchair_graph_batch_sizes()
|
||||||
|
|
||||||
|
self.check_torchair_graph_batch_sizes()
|
||||||
|
|
||||||
|
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||||
|
self.torchair_graph_batch_sizes)
|
||||||
|
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||||
|
torch._logging.set_logs(
|
||||||
|
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||||
|
|
||||||
|
self._check_batch_sizes_consistency()
|
||||||
register_torchair_model()
|
register_torchair_model()
|
||||||
|
|
||||||
def _get_forward_metadata_across_dp_and_pad(
|
def _get_forward_metadata_across_dp_and_pad(
|
||||||
@@ -180,3 +206,215 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
if self.new_kv_cache_bytes > 0:
|
if self.new_kv_cache_bytes > 0:
|
||||||
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
||||||
self.new_kv_cache_bytes)
|
self.new_kv_cache_bytes)
|
||||||
|
|
||||||
|
def _use_aclgraph(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_batch_sizes_consistency(self) -> None:
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return
|
||||||
|
|
||||||
|
local = torch.tensor(self.torchair_graph_batch_sizes,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.int32)
|
||||||
|
gathered_graph_batch_size = local.clone()
|
||||||
|
dist.all_reduce(gathered_graph_batch_size,
|
||||||
|
group=get_dp_group().cpu_group)
|
||||||
|
expected = local * self.dp_size
|
||||||
|
|
||||||
|
if not torch.equal(gathered_graph_batch_size, expected):
|
||||||
|
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
||||||
|
as_tuple=False).flatten().tolist()
|
||||||
|
raise AssertionError(
|
||||||
|
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
||||||
|
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
||||||
|
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
||||||
|
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
||||||
|
if not with_prefill:
|
||||||
|
self.graph_pad_size = graph_pad_size
|
||||||
|
else:
|
||||||
|
super()._update_graph_pad_size(with_prefill, graph_pad_size)
|
||||||
|
|
||||||
|
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||||
|
num_input_tokens, with_prefill,
|
||||||
|
padded_num_tokens_across_dp):
|
||||||
|
"""Override from NPUModelRunner to update input_ids and positions"""
|
||||||
|
input_ids, positions = super()._update_input_ids_and_positions(
|
||||||
|
input_ids, positions, num_input_tokens, with_prefill,
|
||||||
|
padded_num_tokens_across_dp)
|
||||||
|
|
||||||
|
if not with_prefill:
|
||||||
|
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||||
|
positions = self.positions[:padded_num_tokens_across_dp]
|
||||||
|
return input_ids, positions
|
||||||
|
|
||||||
|
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||||
|
padded_num_tokens_across_dp,
|
||||||
|
input_ids, positions,
|
||||||
|
intermediate_tensors,
|
||||||
|
inputs_embeds):
|
||||||
|
model_kwargs = {
|
||||||
|
"kv_caches": self.kv_caches,
|
||||||
|
"attn_metadata": attn_metadata
|
||||||
|
}
|
||||||
|
if not with_prefill:
|
||||||
|
maybe_converting_weight_acl_format(self.model,
|
||||||
|
ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
|
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||||
|
padded_num_tokens_across_dp)
|
||||||
|
hidden_states = compiled_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert self.model is not None
|
||||||
|
maybe_converting_weight_acl_format(self.model,
|
||||||
|
ACL_FORMAT_FRACTAL_ND)
|
||||||
|
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||||
|
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
compiled_model = self.torchair_compiled_models.get(
|
||||||
|
batch_size
|
||||||
|
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
||||||
|
|
||||||
|
if compiled_model:
|
||||||
|
return compiled_model
|
||||||
|
|
||||||
|
import torchair # type: ignore
|
||||||
|
from torchair import patch_for_hcom # type: ignore
|
||||||
|
|
||||||
|
patch_for_hcom()
|
||||||
|
|
||||||
|
if is_310p():
|
||||||
|
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||||
|
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||||
|
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
|
||||||
|
communication_adaptation_310p
|
||||||
|
communication_adaptation_310p()
|
||||||
|
|
||||||
|
config = torchair.CompilerConfig()
|
||||||
|
config.experimental_config.frozen_parameter = True
|
||||||
|
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||||
|
# disable it on 300I Duo platform now.
|
||||||
|
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||||
|
config.experimental_config.enable_view_optimize = \
|
||||||
|
get_ascend_config().torchair_graph_config.enable_view_optimize
|
||||||
|
torch.npu.set_compile_mode(jit_compile=False)
|
||||||
|
if not self.use_cached_npu_graph:
|
||||||
|
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||||
|
self.torchair_compiled_model = torch.compile(
|
||||||
|
self.model,
|
||||||
|
dynamic=True,
|
||||||
|
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
|
backend=npu_backend)
|
||||||
|
return self.torchair_compiled_model
|
||||||
|
else:
|
||||||
|
# Generate a new forward proxy code object to prevent the invalidation of
|
||||||
|
# compilation cache caused by dynamo retracing
|
||||||
|
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||||
|
forward_fn = self.model.forward
|
||||||
|
code = forward_fn.__code__
|
||||||
|
# Mark code object with a new proxy name
|
||||||
|
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||||
|
|
||||||
|
modified_func = types.FunctionType(modified_code,
|
||||||
|
forward_fn.__globals__,
|
||||||
|
name=forward_proxy_name,
|
||||||
|
argdefs=forward_fn.__defaults__)
|
||||||
|
|
||||||
|
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||||
|
self.model, nn.Module)
|
||||||
|
self.torchair_compiled_models[
|
||||||
|
batch_size] = torchair.inference.cache_compile(
|
||||||
|
self.model.__dict__[forward_proxy_name],
|
||||||
|
dynamic=True,
|
||||||
|
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
|
config=config,
|
||||||
|
ge_cache=False)
|
||||||
|
return self.torchair_compiled_models[batch_size]
|
||||||
|
|
||||||
|
def init_torchair_graph_batch_sizes(self):
|
||||||
|
start_graph_batch_size = 4
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
||||||
|
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
||||||
|
|
||||||
|
while (start_graph_batch_size <= self.max_num_reqs):
|
||||||
|
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
||||||
|
start_graph_batch_size *= 2
|
||||||
|
|
||||||
|
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||||
|
for padded_batch_size in self.torchair_graph_batch_sizes:
|
||||||
|
if batch_size <= padded_batch_size:
|
||||||
|
# we treat batch_size as num of requests
|
||||||
|
return padded_batch_size
|
||||||
|
raise ValueError(
|
||||||
|
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
|
||||||
|
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_torchair_graph_batch_sizes(self):
|
||||||
|
# return graph_batch_sizes according to the max number of tokens
|
||||||
|
# first pad according to the number of requests
|
||||||
|
if len(self.torchair_graph_batch_sizes) == 0:
|
||||||
|
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||||
|
else:
|
||||||
|
self.torchair_graph_batch_sizes = sorted(
|
||||||
|
self.torchair_graph_batch_sizes)
|
||||||
|
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
|
||||||
|
self.torchair_graph_batch_sizes.pop()
|
||||||
|
if len(self.torchair_graph_batch_sizes) == 0:
|
||||||
|
logger.warning(
|
||||||
|
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
|
||||||
|
)
|
||||||
|
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||||
|
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
|
||||||
|
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
|
||||||
|
|
||||||
|
# padded max number tokens = max_num_req * decode_token_per_req
|
||||||
|
self.torchair_graph_batch_sizes = [
|
||||||
|
graph_batch_size * self.decode_token_per_req
|
||||||
|
for graph_batch_size in self.torchair_graph_batch_sizes
|
||||||
|
]
|
||||||
|
|
||||||
|
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||||
|
tp_size = self.parallel_config.tensor_parallel_size
|
||||||
|
if self.parallel_config.enable_expert_parallel:
|
||||||
|
new_graph_batch_sizes = []
|
||||||
|
for graph_batch_size in self.torchair_graph_batch_sizes:
|
||||||
|
cur_graph_batch_size = (graph_batch_size + tp_size -
|
||||||
|
1) // tp_size * tp_size
|
||||||
|
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
||||||
|
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
||||||
|
new_graph_batch_sizes.append(cur_graph_batch_size)
|
||||||
|
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
|
||||||
|
and self.decode_token_per_req > 1:
|
||||||
|
logger.warning(
|
||||||
|
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
|
||||||
|
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
|
||||||
|
)
|
||||||
|
self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
||||||
|
|
||||||
|
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||||
|
return True
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ import gc
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import types
|
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
|
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
|
||||||
@@ -39,7 +38,6 @@ from vllm.attention.layer import Attention
|
|||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||||
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
|
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
||||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||||
has_kv_transfer_group)
|
has_kv_transfer_group)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||||
@@ -108,7 +106,6 @@ else:
|
|||||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||||
|
|
||||||
import torch_npu
|
import torch_npu
|
||||||
import vllm.envs as envs_vllm
|
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
|
||||||
@@ -341,11 +338,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||||
|
|
||||||
self.use_aclgraph = (
|
self.use_aclgraph = self._use_aclgraph()
|
||||||
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
|
||||||
and self.compilation_config.level == CompilationLevel.PIECEWISE
|
|
||||||
and not self.model_config.enforce_eager
|
|
||||||
and not ascend_config.torchair_graph_config.enabled)
|
|
||||||
self.aclgraph_batch_sizes = list(
|
self.aclgraph_batch_sizes = list(
|
||||||
reversed(self.compilation_config.cudagraph_capture_sizes))
|
reversed(self.compilation_config.cudagraph_capture_sizes))
|
||||||
|
|
||||||
@@ -357,31 +350,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self._draft_token_ids: Optional[Union[list[list[int]],
|
self._draft_token_ids: Optional[Union[list[list[int]],
|
||||||
torch.Tensor]] = None
|
torch.Tensor]] = None
|
||||||
|
|
||||||
self.new_kv_cache_bytes = -1
|
|
||||||
self.torchair_compiled_model = None # type: ignore
|
|
||||||
self.torchair_compiled_models = {} # type: ignore
|
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
|
||||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
|
||||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
|
||||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
|
||||||
self.init_torchair_graph_batch_sizes()
|
|
||||||
|
|
||||||
self.check_torchair_graph_batch_sizes()
|
|
||||||
|
|
||||||
# graph_block_tables shape: [num_request, cell(max_model_len / block_size)]
|
|
||||||
self.graph_block_tables = np.zeros(
|
|
||||||
(self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req,
|
|
||||||
(self.model_config.max_model_len + self.block_size - 1) //
|
|
||||||
self.block_size),
|
|
||||||
dtype=np.int32)
|
|
||||||
|
|
||||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
|
||||||
self.torchair_graph_batch_sizes)
|
|
||||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
|
||||||
torch._logging.set_logs(
|
|
||||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
|
||||||
|
|
||||||
self.check_batch_sizes_consistency()
|
|
||||||
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
||||||
self.in_profile_run = False
|
self.in_profile_run = False
|
||||||
|
|
||||||
@@ -400,27 +368,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
self.moe_comm_method = AllGatherCommImpl
|
self.moe_comm_method = AllGatherCommImpl
|
||||||
|
|
||||||
def check_batch_sizes_consistency(self) -> None:
|
def _use_aclgraph(self) -> bool:
|
||||||
if not dist.is_initialized():
|
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
||||||
return
|
|
||||||
|
|
||||||
local = torch.tensor(self.torchair_graph_batch_sizes,
|
|
||||||
device="cpu",
|
|
||||||
dtype=torch.int32)
|
|
||||||
gathered_graph_batch_size = local.clone()
|
|
||||||
dist.all_reduce(gathered_graph_batch_size,
|
|
||||||
group=get_dp_group().cpu_group)
|
|
||||||
expected = local * self.dp_size
|
|
||||||
|
|
||||||
if not torch.equal(gathered_graph_batch_size, expected):
|
|
||||||
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
|
||||||
as_tuple=False).flatten().tolist()
|
|
||||||
raise AssertionError(
|
|
||||||
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
|
||||||
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
|
||||||
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
|
||||||
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
# Remove finished requests from the cached states.
|
# Remove finished requests from the cached states.
|
||||||
@@ -1047,14 +996,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||||
|
|
||||||
def _process_reqs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
|
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
|
||||||
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
|
AscendTorchairMetadata], torch.Tensor, np.ndarray, int,
|
||||||
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
|
torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
|
||||||
Optional[set[str]], Optional[set[str]]]:
|
Optional[torch.Tensor], Optional[torch.Tensor],
|
||||||
|
Optional[torch.Tensor]]:
|
||||||
# Check input valid
|
# Check input valid
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
assert total_num_scheduled_tokens > 0
|
assert total_num_scheduled_tokens > 0
|
||||||
@@ -1103,9 +1053,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
||||||
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
logits_indices = cu_num_tokens - 1
|
|
||||||
logits_indices = torch.from_numpy(logits_indices).to(self.device,
|
|
||||||
non_blocking=True)
|
|
||||||
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
||||||
|
|
||||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||||
@@ -1118,7 +1065,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
self._calc_mrope_positions(scheduler_output)
|
self._calc_mrope_positions(scheduler_output)
|
||||||
|
|
||||||
if self.uses_mrope:
|
|
||||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
||||||
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
||||||
@@ -1127,7 +1073,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
|
self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
|
||||||
self.positions[:total_num_scheduled_tokens].copy_(
|
self.positions[:total_num_scheduled_tokens].copy_(
|
||||||
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||||
positions = self.positions[:num_input_tokens]
|
|
||||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||||
|
|
||||||
self.seq_lens_np[:num_reqs] = (
|
self.seq_lens_np[:num_reqs] = (
|
||||||
@@ -1145,34 +1091,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
block_offsets,
|
block_offsets,
|
||||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||||||
use_spec_decode = len(
|
num_valid_tokens)
|
||||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
|
||||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
|
||||||
attn_state = AscendAttentionState.PrefillNoCache
|
|
||||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
|
||||||
elif np.all(num_scheduled_tokens == 1):
|
|
||||||
attn_state = AscendAttentionState.DecodeOnly
|
|
||||||
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
|
||||||
# SpecDecoding now supports seq_len=1 and seq_len=2
|
|
||||||
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
|
|
||||||
attn_state = AscendAttentionState.SpecDecoding
|
|
||||||
# Speculative decoding.
|
|
||||||
elif np.all(num_valid_tokens == 1):
|
|
||||||
if self.use_eagle:
|
|
||||||
attn_state = AscendAttentionState.ChunkedPrefill
|
|
||||||
else:
|
|
||||||
attn_state = AscendAttentionState.SpecDecoding
|
|
||||||
# splitfuse
|
|
||||||
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
|
||||||
attn_state = AscendAttentionState.ChunkedPrefill
|
|
||||||
else:
|
|
||||||
attn_state = AscendAttentionState.PrefillCacheHit
|
|
||||||
|
|
||||||
self.attn_mask = self._make_attention_mask(
|
self.attn_mask = self._make_attention_mask(
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
query_lens=num_scheduled_tokens,
|
query_lens=num_scheduled_tokens,
|
||||||
position=positions,
|
position=self.positions[:num_input_tokens],
|
||||||
attn_state=attn_state)
|
attn_state=attn_state)
|
||||||
self.attn_state = attn_state # type: ignore
|
self.attn_state = attn_state # type: ignore
|
||||||
|
|
||||||
@@ -1191,8 +1116,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||||
]
|
]
|
||||||
|
|
||||||
is_only_prefill = bool(np.all(num_valid_tokens != 1))
|
|
||||||
|
|
||||||
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||||
attn_state,
|
attn_state,
|
||||||
total_num_scheduled_tokens)
|
total_num_scheduled_tokens)
|
||||||
@@ -1202,10 +1125,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
total_num_scheduled_tokens, with_prefill, enable_dbo)
|
total_num_scheduled_tokens, with_prefill, enable_dbo)
|
||||||
self.with_prefill = with_prefill
|
self.with_prefill = with_prefill
|
||||||
self.num_tokens_across_dp = num_tokens_across_dp
|
self.num_tokens_across_dp = num_tokens_across_dp
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp)
|
||||||
self.graph_pad_size = padded_num_tokens_across_dp
|
|
||||||
else:
|
|
||||||
self.graph_pad_size = -1
|
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
||||||
@@ -1221,7 +1141,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
spec_attn_mask=self.spec_attn_mask,
|
spec_attn_mask=self.spec_attn_mask,
|
||||||
attn_state=self.attn_state,
|
attn_state=self.attn_state,
|
||||||
enable_dbo_across_dp=enable_dbo,
|
enable_dbo_across_dp=enable_dbo,
|
||||||
is_only_prefill=is_only_prefill,
|
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
graph_pad_size=self.graph_pad_size,
|
graph_pad_size=self.graph_pad_size,
|
||||||
decode_token_per_req=self.decode_token_per_req,
|
decode_token_per_req=self.decode_token_per_req,
|
||||||
@@ -1248,10 +1168,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Run the multimodal encoder if any.
|
# Run the multimodal encoder if any.
|
||||||
self._execute_mm_encoder(scheduler_output)
|
self._execute_mm_encoder(scheduler_output)
|
||||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||||
else:
|
|
||||||
mm_embeds = []
|
|
||||||
|
|
||||||
if self.is_multimodal_model:
|
|
||||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||||
# embeddings), we always use embeddings (rather than token ids)
|
# embeddings), we always use embeddings (rather than token ids)
|
||||||
# as input to the multimodal model, even when the input is text.
|
# as input to the multimodal model, even when the input is text.
|
||||||
@@ -1273,12 +1190,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# then the embedding layer is not included in the ACL graph.
|
# then the embedding layer is not included in the ACL graph.
|
||||||
input_ids = self.input_ids[:num_input_tokens]
|
input_ids = self.input_ids[:num_input_tokens]
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
if self.uses_mrope:
|
positions = self.positions[:num_input_tokens]
|
||||||
positions = self.mrope_positions[:, :num_input_tokens]
|
input_ids, positions = self._update_input_ids_and_positions(
|
||||||
|
input_ids, positions, num_input_tokens, with_prefill,
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
padded_num_tokens_across_dp)
|
||||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
|
||||||
positions = self.positions[:padded_num_tokens_across_dp]
|
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
@@ -1293,8 +1208,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for k, v in self.intermediate_tensors.items()
|
for k, v in self.intermediate_tensors.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
moe_comm_method = self.moe_comm_method
|
|
||||||
|
|
||||||
# NOTE: Currently this padding logic is really messy,
|
# NOTE: Currently this padding logic is really messy,
|
||||||
# MC2 may not be available in eager mode
|
# MC2 may not be available in eager mode
|
||||||
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
|
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
|
||||||
@@ -1303,52 +1216,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
num_input_tokens = padded_num_tokens_across_dp
|
num_input_tokens = padded_num_tokens_across_dp
|
||||||
|
|
||||||
# Run forward pass
|
|
||||||
with set_ascend_forward_context(
|
|
||||||
attn_metadata,
|
|
||||||
self.vllm_config,
|
|
||||||
num_tokens=num_input_tokens,
|
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
|
||||||
with_prefill=with_prefill,
|
|
||||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
|
||||||
moe_comm_method=moe_comm_method(self.device, self.dtype,
|
|
||||||
self.model_config.hf_config),
|
|
||||||
num_actual_tokens=total_num_scheduled_tokens):
|
|
||||||
with ProfileExecuteDuration().capture_async("forward"):
|
|
||||||
self.maybe_setup_kv_connector(scheduler_output)
|
|
||||||
model_kwargs = {}
|
|
||||||
if self.torchair_graph_enabled:
|
|
||||||
model_kwargs["kv_caches"] = self.kv_caches
|
|
||||||
model_kwargs["attn_metadata"] = attn_metadata
|
|
||||||
if self.torchair_graph_enabled and not with_prefill:
|
|
||||||
maybe_converting_weight_acl_format(self.model,
|
|
||||||
ACL_FORMAT_FRACTAL_NZ)
|
|
||||||
|
|
||||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
|
||||||
padded_num_tokens_across_dp)
|
|
||||||
hidden_states = compiled_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert self.model is not None
|
|
||||||
maybe_converting_weight_acl_format(self.model,
|
|
||||||
ACL_FORMAT_FRACTAL_ND)
|
|
||||||
|
|
||||||
hidden_states = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.maybe_wait_for_kv_save()
|
|
||||||
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
|
||||||
scheduler_output)
|
|
||||||
use_spec_decode = len(
|
use_spec_decode = len(
|
||||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||||
if not use_spec_decode:
|
if not use_spec_decode:
|
||||||
@@ -1358,6 +1225,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# We will ignore the sampled tokens from the partial requests.
|
# We will ignore the sampled tokens from the partial requests.
|
||||||
# TODO: Support prompt logprobs.
|
# TODO: Support prompt logprobs.
|
||||||
spec_decode_metadata = None
|
spec_decode_metadata = None
|
||||||
|
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
|
||||||
|
self.device, non_blocking=True)
|
||||||
else:
|
else:
|
||||||
# Get the number of draft tokens for each request.
|
# Get the number of draft tokens for each request.
|
||||||
# Iterate over the dictionary rather than all requests since not all
|
# Iterate over the dictionary rather than all requests since not all
|
||||||
@@ -1372,13 +1241,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_draft_tokens, cu_num_tokens)
|
num_draft_tokens, cu_num_tokens)
|
||||||
logits_indices = spec_decode_metadata.logits_indices
|
logits_indices = spec_decode_metadata.logits_indices
|
||||||
|
|
||||||
aux_hidden_states = None
|
return (attn_metadata, positions, num_scheduled_tokens,
|
||||||
if self.use_aux_hidden_state_outputs:
|
num_input_tokens, num_tokens_across_dp,
|
||||||
hidden_states, aux_hidden_states = hidden_states
|
padded_num_tokens_across_dp, logits_indices,
|
||||||
|
spec_decode_metadata, input_ids, inputs_embeds,
|
||||||
|
intermediate_tensors)
|
||||||
|
|
||||||
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
|
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||||
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
|
padded_num_tokens_across_dp,
|
||||||
num_scheduled_tokens, finished_sending, finished_recving)
|
input_ids, positions,
|
||||||
|
intermediate_tensors,
|
||||||
|
inputs_embeds):
|
||||||
|
assert self.model is not None
|
||||||
|
maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
||||||
|
num_valid_tokens):
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||||
|
attn_state = AscendAttentionState.PrefillNoCache
|
||||||
|
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||||
|
elif np.all(num_scheduled_tokens == 1):
|
||||||
|
attn_state = AscendAttentionState.DecodeOnly
|
||||||
|
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
||||||
|
# SpecDecoding now supports seq_len=1 and seq_len=2
|
||||||
|
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
|
||||||
|
attn_state = AscendAttentionState.SpecDecoding
|
||||||
|
# Speculative decoding.
|
||||||
|
elif np.all(num_valid_tokens == 1):
|
||||||
|
if self.use_eagle:
|
||||||
|
attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
|
else:
|
||||||
|
attn_state = AscendAttentionState.SpecDecoding
|
||||||
|
# splitfuse
|
||||||
|
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
||||||
|
attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
|
else:
|
||||||
|
attn_state = AscendAttentionState.PrefillCacheHit
|
||||||
|
return attn_state
|
||||||
|
|
||||||
|
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
||||||
|
self.graph_pad_size = -1
|
||||||
|
|
||||||
|
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||||
|
num_input_tokens, with_prefill,
|
||||||
|
padded_num_tokens_across_dp):
|
||||||
|
if self.uses_mrope:
|
||||||
|
positions = self.mrope_positions[:, :num_input_tokens]
|
||||||
|
return input_ids, positions
|
||||||
|
|
||||||
def _get_cumsum_and_arange(
|
def _get_cumsum_and_arange(
|
||||||
self,
|
self,
|
||||||
@@ -1623,8 +1540,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> Union[ModelRunnerOutput, torch.Tensor]:
|
) -> Union[ModelRunnerOutput, torch.Tensor]:
|
||||||
with ProfileExecuteDuration().capture_async(
|
with ProfileExecuteDuration().capture_async("prepare input"):
|
||||||
"prepare input and forward"):
|
|
||||||
self._update_states(scheduler_output)
|
self._update_states(scheduler_output)
|
||||||
if not scheduler_output.total_num_scheduled_tokens:
|
if not scheduler_output.total_num_scheduled_tokens:
|
||||||
if not has_kv_transfer_group():
|
if not has_kv_transfer_group():
|
||||||
@@ -1634,11 +1550,41 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Return empty ModelRunnerOuptut if there's no work to do.
|
# Return empty ModelRunnerOuptut if there's no work to do.
|
||||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||||
return self.kv_connector_no_forward(scheduler_output)
|
return self.kv_connector_no_forward(scheduler_output)
|
||||||
(attn_metadata, hidden_states, spec_decode_metadata, positions,
|
(attn_metadata, positions, num_scheduled_tokens_np,
|
||||||
num_scheduled_tokens, logits_indices, aux_hidden_states,
|
num_input_tokens, num_tokens_across_dp,
|
||||||
num_scheduled_tokens_np, finished_sending,
|
padded_num_tokens_across_dp, logits_indices, spec_decode_metadata,
|
||||||
finished_recving) = (self._process_reqs(scheduler_output,
|
input_ids, inputs_embeds,
|
||||||
intermediate_tensors))
|
intermediate_tensors) = (self._prepare_inputs(
|
||||||
|
scheduler_output, intermediate_tensors))
|
||||||
|
|
||||||
|
# Run forward pass
|
||||||
|
with ProfileExecuteDuration().capture_async("forward"):
|
||||||
|
with set_ascend_forward_context(
|
||||||
|
attn_metadata,
|
||||||
|
self.vllm_config,
|
||||||
|
num_tokens=num_input_tokens,
|
||||||
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
|
with_prefill=self.with_prefill,
|
||||||
|
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||||
|
moe_comm_method=self.moe_comm_method(
|
||||||
|
self.device, self.dtype, self.model_config.hf_config),
|
||||||
|
num_actual_tokens=scheduler_output.
|
||||||
|
total_num_scheduled_tokens):
|
||||||
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
|
|
||||||
|
hidden_states = self._generate_process_reqs_hidden_states(
|
||||||
|
attn_metadata, self.with_prefill,
|
||||||
|
padded_num_tokens_across_dp, input_ids, positions,
|
||||||
|
intermediate_tensors, inputs_embeds)
|
||||||
|
|
||||||
|
self.maybe_wait_for_kv_save()
|
||||||
|
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
||||||
|
scheduler_output)
|
||||||
|
|
||||||
|
aux_hidden_states = None
|
||||||
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
hidden_states, aux_hidden_states = hidden_states
|
||||||
|
|
||||||
kv_connector_output = None
|
kv_connector_output = None
|
||||||
if finished_sending is not None or finished_recving is not None:
|
if finished_sending is not None or finished_recving is not None:
|
||||||
kv_connector_output = KVConnectorOutput(
|
kv_connector_output = KVConnectorOutput(
|
||||||
@@ -1667,10 +1613,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
logits = None
|
logits = None
|
||||||
else:
|
else:
|
||||||
if self.input_batch.pooling_params:
|
if self.input_batch.pooling_params:
|
||||||
return self._pool(hidden_states, num_scheduled_tokens,
|
return self._pool(
|
||||||
num_scheduled_tokens_np,
|
hidden_states,
|
||||||
finished_sending, finished_recving,
|
scheduler_output.total_num_scheduled_tokens,
|
||||||
kv_connector_output)
|
num_scheduled_tokens_np, finished_sending,
|
||||||
|
finished_recving, kv_connector_output)
|
||||||
sample_hidden_states = hidden_states[logits_indices]
|
sample_hidden_states = hidden_states[logits_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
if broadcast_pp_output:
|
if broadcast_pp_output:
|
||||||
@@ -1746,7 +1693,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Compute prompt logprobs if needed.
|
# Compute prompt logprobs if needed.
|
||||||
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
||||||
hidden_states[:num_scheduled_tokens],
|
hidden_states[:scheduler_output.total_num_scheduled_tokens],
|
||||||
scheduler_output,
|
scheduler_output,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1796,7 +1743,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
scheduler_output,
|
scheduler_output,
|
||||||
spec_decode_metadata,
|
spec_decode_metadata,
|
||||||
positions,
|
positions,
|
||||||
num_scheduled_tokens,
|
scheduler_output.total_num_scheduled_tokens,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
aux_hidden_states,
|
aux_hidden_states,
|
||||||
@@ -2191,72 +2138,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
logger.info("Loading model weights took %.4f GB",
|
logger.info("Loading model weights took %.4f GB",
|
||||||
m.consumed_memory / float(2**30))
|
m.consumed_memory / float(2**30))
|
||||||
|
|
||||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
|
||||||
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
compiled_model = self.torchair_compiled_models.get(
|
|
||||||
batch_size
|
|
||||||
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
|
||||||
|
|
||||||
if compiled_model:
|
|
||||||
return compiled_model
|
|
||||||
|
|
||||||
import torchair # type: ignore
|
|
||||||
from torchair import patch_for_hcom # type: ignore
|
|
||||||
|
|
||||||
patch_for_hcom()
|
|
||||||
|
|
||||||
if is_310p():
|
|
||||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
|
||||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
|
||||||
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
|
|
||||||
communication_adaptation_310p
|
|
||||||
communication_adaptation_310p()
|
|
||||||
|
|
||||||
config = torchair.CompilerConfig()
|
|
||||||
config.experimental_config.frozen_parameter = True
|
|
||||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
|
||||||
# disable it on 300I Duo platform now.
|
|
||||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
|
||||||
config.experimental_config.enable_view_optimize = \
|
|
||||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
|
||||||
torch.npu.set_compile_mode(jit_compile=False)
|
|
||||||
if not self.use_cached_npu_graph:
|
|
||||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
|
||||||
self.torchair_compiled_model = torch.compile(
|
|
||||||
self.model,
|
|
||||||
dynamic=True,
|
|
||||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
|
||||||
backend=npu_backend)
|
|
||||||
return self.torchair_compiled_model
|
|
||||||
else:
|
|
||||||
# Generate a new forward proxy code object to prevent the invalidation of
|
|
||||||
# compilation cache caused by dynamo retracing
|
|
||||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
|
||||||
forward_fn = self.model.forward
|
|
||||||
code = forward_fn.__code__
|
|
||||||
# Mark code object with a new proxy name
|
|
||||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
|
||||||
|
|
||||||
modified_func = types.FunctionType(modified_code,
|
|
||||||
forward_fn.__globals__,
|
|
||||||
name=forward_proxy_name,
|
|
||||||
argdefs=forward_fn.__defaults__)
|
|
||||||
|
|
||||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
|
||||||
self.model, nn.Module)
|
|
||||||
self.torchair_compiled_models[
|
|
||||||
batch_size] = torchair.inference.cache_compile(
|
|
||||||
self.model.__dict__[forward_proxy_name],
|
|
||||||
dynamic=True,
|
|
||||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
|
||||||
config=config,
|
|
||||||
ge_cache=False)
|
|
||||||
return self.torchair_compiled_models[batch_size]
|
|
||||||
|
|
||||||
def _convert_torch_format(self, tensor):
|
def _convert_torch_format(self, tensor):
|
||||||
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
|
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
|
||||||
return tensor
|
return tensor
|
||||||
@@ -2707,7 +2588,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
positions[:num_scheduled_tokens],
|
positions[:num_scheduled_tokens],
|
||||||
hidden_states[:num_scheduled_tokens],
|
hidden_states[:num_scheduled_tokens],
|
||||||
attn_metadata.slot_mapping[:num_scheduled_tokens],
|
attn_metadata.slot_mapping[:num_scheduled_tokens],
|
||||||
is_torchair_graph=self.torchair_graph_enabled,
|
is_torchair_graph=self._build_drafter_prepare_inputs_torchair_param(),
|
||||||
)
|
)
|
||||||
|
|
||||||
draft_token_ids = self.drafter.propose(
|
draft_token_ids = self.drafter.propose(
|
||||||
@@ -2818,72 +2699,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
return prompt_logprobs_dict
|
return prompt_logprobs_dict
|
||||||
|
|
||||||
def init_torchair_graph_batch_sizes(self):
|
|
||||||
start_graph_batch_size = 4
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
|
|
||||||
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
|
||||||
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
|
||||||
|
|
||||||
while (start_graph_batch_size <= self.max_num_reqs):
|
|
||||||
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
|
||||||
start_graph_batch_size *= 2
|
|
||||||
|
|
||||||
def select_torchair_padded_batch_size(self, batch_size: int):
|
|
||||||
for padded_batch_size in self.torchair_graph_batch_sizes:
|
|
||||||
if batch_size <= padded_batch_size:
|
|
||||||
# we treat batch_size as num of requests
|
|
||||||
return padded_batch_size
|
|
||||||
raise ValueError(
|
|
||||||
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
|
|
||||||
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_torchair_graph_batch_sizes(self):
|
|
||||||
# return graph_batch_sizes according to the max number of tokens
|
|
||||||
# first pad according to the number of requests
|
|
||||||
if len(self.torchair_graph_batch_sizes) == 0:
|
|
||||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
|
||||||
else:
|
|
||||||
self.torchair_graph_batch_sizes = sorted(
|
|
||||||
self.torchair_graph_batch_sizes)
|
|
||||||
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
|
|
||||||
self.torchair_graph_batch_sizes.pop()
|
|
||||||
if len(self.torchair_graph_batch_sizes) == 0:
|
|
||||||
logger.warning(
|
|
||||||
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
|
|
||||||
)
|
|
||||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
|
||||||
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
|
|
||||||
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
|
|
||||||
|
|
||||||
# padded max number tokens = max_num_req * decode_token_per_req
|
|
||||||
self.torchair_graph_batch_sizes = [
|
|
||||||
graph_batch_size * self.decode_token_per_req
|
|
||||||
for graph_batch_size in self.torchair_graph_batch_sizes
|
|
||||||
]
|
|
||||||
|
|
||||||
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
|
|
||||||
tp_size = self.parallel_config.tensor_parallel_size
|
|
||||||
if self.parallel_config.enable_expert_parallel:
|
|
||||||
new_graph_batch_sizes = []
|
|
||||||
for graph_batch_size in self.torchair_graph_batch_sizes:
|
|
||||||
cur_graph_batch_size = (graph_batch_size + tp_size -
|
|
||||||
1) // tp_size * tp_size
|
|
||||||
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
|
||||||
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
|
||||||
new_graph_batch_sizes.append(cur_graph_batch_size)
|
|
||||||
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
|
|
||||||
and self.decode_token_per_req > 1:
|
|
||||||
logger.warning(
|
|
||||||
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
|
|
||||||
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
|
|
||||||
)
|
|
||||||
self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
|
||||||
|
|
||||||
def get_supported_pooling_tasks(self):
|
def get_supported_pooling_tasks(self):
|
||||||
model = self.get_model()
|
model = self.get_model()
|
||||||
if not is_pooling_model(model):
|
if not is_pooling_model(model):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return list(model.pooler.get_supported_tasks())
|
return list(model.pooler.get_supported_tasks())
|
||||||
|
|
||||||
|
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||||
|
return False
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ class MtpProposer:
|
|||||||
device=self.runner.device)
|
device=self.runner.device)
|
||||||
self.torchair_compiled_model = None # type: ignore
|
self.torchair_compiled_model = None # type: ignore
|
||||||
self.torchair_compiled_models = {} # type: ignore
|
self.torchair_compiled_models = {} # type: ignore
|
||||||
|
self.torchair_graph_enabled = get_ascend_config(
|
||||||
|
).torchair_graph_config.enabled
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_inputs(
|
def prepare_inputs(
|
||||||
@@ -136,7 +138,7 @@ class MtpProposer:
|
|||||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||||
# Replace the last token with the next token.
|
# Replace the last token with the next token.
|
||||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||||
if token_indices is not None and self.runner.torchair_graph_enabled:
|
if token_indices is not None and self.torchair_graph_enabled:
|
||||||
last_token_indices = token_indices
|
last_token_indices = token_indices
|
||||||
|
|
||||||
self.input_ids[last_token_indices] = next_token_ids
|
self.input_ids[last_token_indices] = next_token_ids
|
||||||
@@ -154,7 +156,7 @@ class MtpProposer:
|
|||||||
# input_batch=self.runner.input_batch,
|
# input_batch=self.runner.input_batch,
|
||||||
# scheduler_output=self.runner.scheduler_output,
|
# scheduler_output=self.runner.scheduler_output,
|
||||||
# )
|
# )
|
||||||
is_running_torchair = self.runner.torchair_graph_enabled and \
|
is_running_torchair = self.torchair_graph_enabled and \
|
||||||
not self.runner.with_prefill
|
not self.runner.with_prefill
|
||||||
|
|
||||||
if is_running_torchair:
|
if is_running_torchair:
|
||||||
@@ -193,7 +195,7 @@ class MtpProposer:
|
|||||||
attn_metadata.prefill.input_positions = target_positions
|
attn_metadata.prefill.input_positions = target_positions
|
||||||
attn_metadata.prefill.seq_lens = seq_lens
|
attn_metadata.prefill.seq_lens = seq_lens
|
||||||
|
|
||||||
if not self.runner.torchair_graph_enabled:
|
if not self.torchair_graph_enabled:
|
||||||
# torch mode need to update num_tokens_across_dp
|
# torch mode need to update num_tokens_across_dp
|
||||||
# TODO: adapt enable_dbo later
|
# TODO: adapt enable_dbo later
|
||||||
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
||||||
@@ -216,7 +218,7 @@ class MtpProposer:
|
|||||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
model_kwargs["attn_metadata"] = attn_metadata
|
model_kwargs["attn_metadata"] = attn_metadata
|
||||||
if self.runner.torchair_graph_enabled:
|
if self.torchair_graph_enabled:
|
||||||
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
|
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
|
||||||
if is_running_torchair:
|
if is_running_torchair:
|
||||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
||||||
@@ -280,12 +282,12 @@ class MtpProposer:
|
|||||||
skip_attn: bool = False,
|
skip_attn: bool = False,
|
||||||
num_reqs: int = 0,
|
num_reqs: int = 0,
|
||||||
num_tokens_across_dp=None) -> None:
|
num_tokens_across_dp=None) -> None:
|
||||||
if not self.runner.torchair_graph_enabled:
|
if not self.torchair_graph_enabled:
|
||||||
# TODO: adapt enable_dbo later
|
# TODO: adapt enable_dbo later
|
||||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||||
_) = self.runner._get_forward_metadata_across_dp_and_pad(
|
_) = self.runner._get_forward_metadata_across_dp_and_pad(
|
||||||
num_tokens, with_prefill, False)
|
num_tokens, with_prefill, False)
|
||||||
is_running_torchair = self.runner.torchair_graph_enabled and \
|
is_running_torchair = self.torchair_graph_enabled and \
|
||||||
not with_prefill
|
not with_prefill
|
||||||
|
|
||||||
if is_running_torchair:
|
if is_running_torchair:
|
||||||
|
|||||||
Reference in New Issue
Block a user