refact runner model v1 (#2461)
refact model runner v1
### What this PR does / why we need it?
1. Separate the execute model logic from the prepare input logic
2. Disassemble the torchchair in model runner v1
- vLLM version: v0.10.0
- vLLM main:
68fcd3fa73
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -17,21 +17,29 @@
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
#
|
||||
|
||||
import types
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist,
|
||||
register_torchair_model,
|
||||
write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
maybe_converting_weight_acl_format)
|
||||
is_310p, maybe_converting_weight_acl_format)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
@@ -39,6 +47,24 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
super().__init__(vllm_config, device)
|
||||
ascend_config = get_ascend_config()
|
||||
self.new_kv_cache_bytes = -1
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
self.check_torchair_graph_batch_sizes()
|
||||
|
||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||
self.torchair_graph_batch_sizes)
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
self._check_batch_sizes_consistency()
|
||||
register_torchair_model()
|
||||
|
||||
def _get_forward_metadata_across_dp_and_pad(
|
||||
@@ -180,3 +206,215 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
if self.new_kv_cache_bytes > 0:
|
||||
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
||||
self.new_kv_cache_bytes)
|
||||
|
||||
def _use_aclgraph(self) -> bool:
|
||||
return False
|
||||
|
||||
def _check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
|
||||
local = torch.tensor(self.torchair_graph_batch_sizes,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
gathered_graph_batch_size = local.clone()
|
||||
dist.all_reduce(gathered_graph_batch_size,
|
||||
group=get_dp_group().cpu_group)
|
||||
expected = local * self.dp_size
|
||||
|
||||
if not torch.equal(gathered_graph_batch_size, expected):
|
||||
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
||||
as_tuple=False).flatten().tolist()
|
||||
raise AssertionError(
|
||||
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
||||
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
||||
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
||||
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
||||
)
|
||||
|
||||
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
||||
if not with_prefill:
|
||||
self.graph_pad_size = graph_pad_size
|
||||
else:
|
||||
super()._update_graph_pad_size(with_prefill, graph_pad_size)
|
||||
|
||||
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||
num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp):
|
||||
"""Override from NPUModelRunner to update input_ids and positions"""
|
||||
input_ids, positions = super()._update_input_ids_and_positions(
|
||||
input_ids, positions, num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp)
|
||||
|
||||
if not with_prefill:
|
||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:padded_num_tokens_across_dp]
|
||||
return input_ids, positions
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
padded_num_tokens_across_dp,
|
||||
input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds):
|
||||
model_kwargs = {
|
||||
"kv_caches": self.kv_caches,
|
||||
"attn_metadata": attn_metadata
|
||||
}
|
||||
if not with_prefill:
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
batch_size
|
||||
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
||||
|
||||
if compiled_model:
|
||||
return compiled_model
|
||||
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if is_310p():
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
|
||||
communication_adaptation_310p
|
||||
communication_adaptation_310p()
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.torchair_compiled_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=True,
|
||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
backend=npu_backend)
|
||||
return self.torchair_compiled_model
|
||||
else:
|
||||
# Generate a new forward proxy code object to prevent the invalidation of
|
||||
# compilation cache caused by dynamo retracing
|
||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||
forward_fn = self.model.forward
|
||||
code = forward_fn.__code__
|
||||
# Mark code object with a new proxy name
|
||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||
|
||||
modified_func = types.FunctionType(modified_code,
|
||||
forward_fn.__globals__,
|
||||
name=forward_proxy_name,
|
||||
argdefs=forward_fn.__defaults__)
|
||||
|
||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||
self.model, nn.Module)
|
||||
self.torchair_compiled_models[
|
||||
batch_size] = torchair.inference.cache_compile(
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=True,
|
||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
|
||||
def init_torchair_graph_batch_sizes(self):
|
||||
start_graph_batch_size = 4
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
||||
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
||||
|
||||
while (start_graph_batch_size <= self.max_num_reqs):
|
||||
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
||||
start_graph_batch_size *= 2
|
||||
|
||||
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||
for padded_batch_size in self.torchair_graph_batch_sizes:
|
||||
if batch_size <= padded_batch_size:
|
||||
# we treat batch_size as num of requests
|
||||
return padded_batch_size
|
||||
raise ValueError(
|
||||
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
|
||||
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
|
||||
)
|
||||
|
||||
def check_torchair_graph_batch_sizes(self):
|
||||
# return graph_batch_sizes according to the max number of tokens
|
||||
# first pad according to the number of requests
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
else:
|
||||
self.torchair_graph_batch_sizes = sorted(
|
||||
self.torchair_graph_batch_sizes)
|
||||
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.pop()
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
logger.warning(
|
||||
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
|
||||
)
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
|
||||
|
||||
# padded max number tokens = max_num_req * decode_token_per_req
|
||||
self.torchair_graph_batch_sizes = [
|
||||
graph_batch_size * self.decode_token_per_req
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes
|
||||
]
|
||||
|
||||
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
new_graph_batch_sizes = []
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes:
|
||||
cur_graph_batch_size = (graph_batch_size + tp_size -
|
||||
1) // tp_size * tp_size
|
||||
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
||||
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
||||
new_graph_batch_sizes.append(cur_graph_batch_size)
|
||||
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
|
||||
and self.decode_token_per_req > 1:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
|
||||
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
|
||||
)
|
||||
self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
||||
|
||||
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||
return True
|
||||
|
||||
@@ -22,7 +22,6 @@ import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
|
||||
@@ -39,7 +38,6 @@ from vllm.attention.layer import Attention
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
|
||||
from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
@@ -108,7 +106,6 @@ else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
@@ -341,11 +338,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
pin_memory=True)
|
||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||
|
||||
self.use_aclgraph = (
|
||||
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and self.compilation_config.level == CompilationLevel.PIECEWISE
|
||||
and not self.model_config.enforce_eager
|
||||
and not ascend_config.torchair_graph_config.enabled)
|
||||
self.use_aclgraph = self._use_aclgraph()
|
||||
self.aclgraph_batch_sizes = list(
|
||||
reversed(self.compilation_config.cudagraph_capture_sizes))
|
||||
|
||||
@@ -357,31 +350,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self._draft_token_ids: Optional[Union[list[list[int]],
|
||||
torch.Tensor]] = None
|
||||
|
||||
self.new_kv_cache_bytes = -1
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
self.check_torchair_graph_batch_sizes()
|
||||
|
||||
# graph_block_tables shape: [num_request, cell(max_model_len / block_size)]
|
||||
self.graph_block_tables = np.zeros(
|
||||
(self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req,
|
||||
(self.model_config.max_model_len + self.block_size - 1) //
|
||||
self.block_size),
|
||||
dtype=np.int32)
|
||||
|
||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||
self.torchair_graph_batch_sizes)
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
self.check_batch_sizes_consistency()
|
||||
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
||||
self.in_profile_run = False
|
||||
|
||||
@@ -400,27 +368,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
self.moe_comm_method = AllGatherCommImpl
|
||||
|
||||
def check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
|
||||
local = torch.tensor(self.torchair_graph_batch_sizes,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
gathered_graph_batch_size = local.clone()
|
||||
dist.all_reduce(gathered_graph_batch_size,
|
||||
group=get_dp_group().cpu_group)
|
||||
expected = local * self.dp_size
|
||||
|
||||
if not torch.equal(gathered_graph_batch_size, expected):
|
||||
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
||||
as_tuple=False).flatten().tolist()
|
||||
raise AssertionError(
|
||||
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
||||
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
||||
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
||||
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
||||
)
|
||||
def _use_aclgraph(self) -> bool:
|
||||
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
# Remove finished requests from the cached states.
|
||||
@@ -1047,14 +996,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=torch.int32)
|
||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||
|
||||
def _process_reqs(
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
|
||||
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
|
||||
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
|
||||
Optional[set[str]], Optional[set[str]]]:
|
||||
AscendTorchairMetadata], torch.Tensor, np.ndarray, int,
|
||||
torch.Tensor, int, torch.Tensor, SpecDecodeMetadata,
|
||||
Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
# Check input valid
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
@@ -1103,9 +1053,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
cu_num_tokens = np.cumsum(num_scheduled_tokens)
|
||||
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
|
||||
num_scheduled_tokens)
|
||||
logits_indices = cu_num_tokens - 1
|
||||
logits_indices = torch.from_numpy(logits_indices).to(self.device,
|
||||
non_blocking=True)
|
||||
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
|
||||
|
||||
positions_np = self.positions_np[:total_num_scheduled_tokens]
|
||||
@@ -1118,7 +1065,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.uses_mrope:
|
||||
self._calc_mrope_positions(scheduler_output)
|
||||
|
||||
if self.uses_mrope:
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
||||
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
||||
@@ -1127,7 +1073,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.positions[total_num_scheduled_tokens:num_input_tokens].zero_()
|
||||
self.positions[:total_num_scheduled_tokens].copy_(
|
||||
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
positions = self.positions[:num_input_tokens]
|
||||
|
||||
self.query_lens = torch.from_numpy(num_scheduled_tokens)
|
||||
|
||||
self.seq_lens_np[:num_reqs] = (
|
||||
@@ -1145,34 +1091,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
||||
# SpecDecoding now supports seq_len=1 and seq_len=2
|
||||
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# Speculative decoding.
|
||||
elif np.all(num_valid_tokens == 1):
|
||||
if self.use_eagle:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
else:
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# splitfuse
|
||||
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
else:
|
||||
attn_state = AscendAttentionState.PrefillCacheHit
|
||||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens)
|
||||
|
||||
self.attn_mask = self._make_attention_mask(
|
||||
seq_lens=seq_lens,
|
||||
query_lens=num_scheduled_tokens,
|
||||
position=positions,
|
||||
position=self.positions[:num_input_tokens],
|
||||
attn_state=attn_state)
|
||||
self.attn_state = attn_state # type: ignore
|
||||
|
||||
@@ -1191,8 +1116,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
|
||||
is_only_prefill = bool(np.all(num_valid_tokens != 1))
|
||||
|
||||
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||
attn_state,
|
||||
total_num_scheduled_tokens)
|
||||
@@ -1202,10 +1125,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
total_num_scheduled_tokens, with_prefill, enable_dbo)
|
||||
self.with_prefill = with_prefill
|
||||
self.num_tokens_across_dp = num_tokens_across_dp
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
self.graph_pad_size = padded_num_tokens_across_dp
|
||||
else:
|
||||
self.graph_pad_size = -1
|
||||
self._update_graph_pad_size(with_prefill, padded_num_tokens_across_dp)
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
|
||||
@@ -1221,7 +1141,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
spec_attn_mask=self.spec_attn_mask,
|
||||
attn_state=self.attn_state,
|
||||
enable_dbo_across_dp=enable_dbo,
|
||||
is_only_prefill=is_only_prefill,
|
||||
is_only_prefill=bool(np.all(num_valid_tokens != 1)),
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
graph_pad_size=self.graph_pad_size,
|
||||
decode_token_per_req=self.decode_token_per_req,
|
||||
@@ -1248,10 +1168,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Run the multimodal encoder if any.
|
||||
self._execute_mm_encoder(scheduler_output)
|
||||
mm_embeds = self._gather_mm_embeddings(scheduler_output)
|
||||
else:
|
||||
mm_embeds = []
|
||||
|
||||
if self.is_multimodal_model:
|
||||
# NOTE(woosuk): To unify token ids and soft tokens (vision
|
||||
# embeddings), we always use embeddings (rather than token ids)
|
||||
# as input to the multimodal model, even when the input is text.
|
||||
@@ -1273,12 +1190,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# then the embedding layer is not included in the ACL graph.
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
inputs_embeds = None
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_input_tokens]
|
||||
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:num_input_tokens]
|
||||
input_ids, positions = self._update_input_ids_and_positions(
|
||||
input_ids, positions, num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp)
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
@@ -1293,8 +1208,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
moe_comm_method = self.moe_comm_method
|
||||
|
||||
# NOTE: Currently this padding logic is really messy,
|
||||
# MC2 may not be available in eager mode
|
||||
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
|
||||
@@ -1303,52 +1216,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
num_input_tokens = padded_num_tokens_across_dp
|
||||
|
||||
# Run forward pass
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||
moe_comm_method=moe_comm_method(self.device, self.dtype,
|
||||
self.model_config.hf_config),
|
||||
num_actual_tokens=total_num_scheduled_tokens):
|
||||
with ProfileExecuteDuration().capture_async("forward"):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
model_kwargs = {}
|
||||
if self.torchair_graph_enabled:
|
||||
model_kwargs["kv_caches"] = self.kv_caches
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
self.maybe_wait_for_kv_save()
|
||||
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
||||
scheduler_output)
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if not use_spec_decode:
|
||||
@@ -1358,6 +1225,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# We will ignore the sampled tokens from the partial requests.
|
||||
# TODO: Support prompt logprobs.
|
||||
spec_decode_metadata = None
|
||||
logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
|
||||
self.device, non_blocking=True)
|
||||
else:
|
||||
# Get the number of draft tokens for each request.
|
||||
# Iterate over the dictionary rather than all requests since not all
|
||||
@@ -1372,13 +1241,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
|
||||
aux_hidden_states = None
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
return (attn_metadata, positions, num_scheduled_tokens,
|
||||
num_input_tokens, num_tokens_across_dp,
|
||||
padded_num_tokens_across_dp, logits_indices,
|
||||
spec_decode_metadata, input_ids, inputs_embeds,
|
||||
intermediate_tensors)
|
||||
|
||||
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
|
||||
num_scheduled_tokens, finished_sending, finished_recving)
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
padded_num_tokens_across_dp,
|
||||
input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds):
|
||||
assert self.model is not None
|
||||
maybe_converting_weight_acl_format(self.model, ACL_FORMAT_FRACTAL_ND)
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens):
|
||||
ascend_config = get_ascend_config()
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
||||
# SpecDecoding now supports seq_len=1 and seq_len=2
|
||||
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# Speculative decoding.
|
||||
elif np.all(num_valid_tokens == 1):
|
||||
if self.use_eagle:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
else:
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# splitfuse
|
||||
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
else:
|
||||
attn_state = AscendAttentionState.PrefillCacheHit
|
||||
return attn_state
|
||||
|
||||
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
||||
self.graph_pad_size = -1
|
||||
|
||||
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||
num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp):
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions[:, :num_input_tokens]
|
||||
return input_ids, positions
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
self,
|
||||
@@ -1623,8 +1540,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> Union[ModelRunnerOutput, torch.Tensor]:
|
||||
with ProfileExecuteDuration().capture_async(
|
||||
"prepare input and forward"):
|
||||
with ProfileExecuteDuration().capture_async("prepare input"):
|
||||
self._update_states(scheduler_output)
|
||||
if not scheduler_output.total_num_scheduled_tokens:
|
||||
if not has_kv_transfer_group():
|
||||
@@ -1634,11 +1550,41 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Return empty ModelRunnerOuptut if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
return self.kv_connector_no_forward(scheduler_output)
|
||||
(attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
num_scheduled_tokens, logits_indices, aux_hidden_states,
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving) = (self._process_reqs(scheduler_output,
|
||||
intermediate_tensors))
|
||||
(attn_metadata, positions, num_scheduled_tokens_np,
|
||||
num_input_tokens, num_tokens_across_dp,
|
||||
padded_num_tokens_across_dp, logits_indices, spec_decode_metadata,
|
||||
input_ids, inputs_embeds,
|
||||
intermediate_tensors) = (self._prepare_inputs(
|
||||
scheduler_output, intermediate_tensors))
|
||||
|
||||
# Run forward pass
|
||||
with ProfileExecuteDuration().capture_async("forward"):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=self.with_prefill,
|
||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||
moe_comm_method=self.moe_comm_method(
|
||||
self.device, self.dtype, self.model_config.hf_config),
|
||||
num_actual_tokens=scheduler_output.
|
||||
total_num_scheduled_tokens):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
|
||||
hidden_states = self._generate_process_reqs_hidden_states(
|
||||
attn_metadata, self.with_prefill,
|
||||
padded_num_tokens_across_dp, input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
self.maybe_wait_for_kv_save()
|
||||
finished_sending, finished_recving = self.get_finished_kv_transfer(
|
||||
scheduler_output)
|
||||
|
||||
aux_hidden_states = None
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
kv_connector_output = None
|
||||
if finished_sending is not None or finished_recving is not None:
|
||||
kv_connector_output = KVConnectorOutput(
|
||||
@@ -1667,10 +1613,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logits = None
|
||||
else:
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np,
|
||||
finished_sending, finished_recving,
|
||||
kv_connector_output)
|
||||
return self._pool(
|
||||
hidden_states,
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving, kv_connector_output)
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
@@ -1746,7 +1693,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Compute prompt logprobs if needed.
|
||||
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
|
||||
hidden_states[:num_scheduled_tokens],
|
||||
hidden_states[:scheduler_output.total_num_scheduled_tokens],
|
||||
scheduler_output,
|
||||
)
|
||||
|
||||
@@ -1796,7 +1743,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output,
|
||||
spec_decode_metadata,
|
||||
positions,
|
||||
num_scheduled_tokens,
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
hidden_states,
|
||||
attn_metadata,
|
||||
aux_hidden_states,
|
||||
@@ -2191,72 +2138,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logger.info("Loading model weights took %.4f GB",
|
||||
m.consumed_memory / float(2**30))
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
batch_size
|
||||
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
||||
|
||||
if compiled_model:
|
||||
return compiled_model
|
||||
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if is_310p():
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
|
||||
communication_adaptation_310p
|
||||
communication_adaptation_310p()
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.torchair_compiled_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=True,
|
||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
backend=npu_backend)
|
||||
return self.torchair_compiled_model
|
||||
else:
|
||||
# Generate a new forward proxy code object to prevent the invalidation of
|
||||
# compilation cache caused by dynamo retracing
|
||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||
forward_fn = self.model.forward
|
||||
code = forward_fn.__code__
|
||||
# Mark code object with a new proxy name
|
||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||
|
||||
modified_func = types.FunctionType(modified_code,
|
||||
forward_fn.__globals__,
|
||||
name=forward_proxy_name,
|
||||
argdefs=forward_fn.__defaults__)
|
||||
|
||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||
self.model, nn.Module)
|
||||
self.torchair_compiled_models[
|
||||
batch_size] = torchair.inference.cache_compile(
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=True,
|
||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
|
||||
def _convert_torch_format(self, tensor):
    """Cast *tensor* into the runner's ACL memory format and return it.

    NOTE(review): relies on the module-level ``ACL_FORMAT`` constant —
    presumably FRACTAL_ND or FRACTAL_NZ depending on platform; confirm
    where it is assigned.
    """
    return torch_npu.npu_format_cast(tensor, ACL_FORMAT)
|
||||
@@ -2707,7 +2588,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions[:num_scheduled_tokens],
|
||||
hidden_states[:num_scheduled_tokens],
|
||||
attn_metadata.slot_mapping[:num_scheduled_tokens],
|
||||
is_torchair_graph=self.torchair_graph_enabled,
|
||||
is_torchair_graph=self._build_drafter_prepare_inputs_torchair_param(),
|
||||
)
|
||||
|
||||
draft_token_ids = self.drafter.propose(
|
||||
@@ -2818,72 +2699,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return prompt_logprobs_dict
|
||||
|
||||
def init_torchair_graph_batch_sizes(self):
    """Seed ``torchair_graph_batch_sizes`` with doubling batch sizes.

    Appends 4, 8, 16, ... (starting no lower than the TP world size) up to
    and including values not exceeding ``self.max_num_reqs``.
    """
    tp_size = get_tensor_model_parallel_world_size()

    # NOTE: when using all2all | mc2 the `num_tokens` dimension is sliced
    # into `tp_size` blocks, so the smallest graph batch must cover at
    # least one token per rank.
    graph_batch = max(4, tp_size)

    while graph_batch <= self.max_num_reqs:
        self.torchair_graph_batch_sizes.append(graph_batch)
        graph_batch *= 2
|
||||
|
||||
def select_torchair_padded_batch_size(self, batch_size: int):
    """Pick the smallest configured graph batch size that fits *batch_size*.

    ``batch_size`` is treated as the number of requests.

    Raises:
        ValueError: when *batch_size* exceeds every configured size.
    """
    fit = next((bucket for bucket in self.torchair_graph_batch_sizes
                if batch_size <= bucket), None)
    if fit is None:
        raise ValueError(
            f"cur batch_size is invalid, torchair_graph_batch_sizes is "
            f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
        )
    return fit
|
||||
|
||||
def check_torchair_graph_batch_sizes(self):
    """Normalize ``torchair_graph_batch_sizes`` for the scheduler limits.

    Three stages:
      1. pad/clamp by the max number of requests (``self.max_num_reqs``);
      2. scale to token counts via ``self.decode_token_per_req``;
      3. with expert parallelism, round each size up to a multiple of the
         TP size and drop sizes above ``max_num_batched_tokens``.
    """
    # return graph_batch_sizes according to the max number of tokens
    # first pad according to the number of requests
    if len(self.torchair_graph_batch_sizes) == 0:
        self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
    else:
        self.torchair_graph_batch_sizes = sorted(
            self.torchair_graph_batch_sizes)
        while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
            self.torchair_graph_batch_sizes.pop()
            # Guard: if every configured size was too large, fall back to
            # the default before the while-condition indexes [-1] again.
            if len(self.torchair_graph_batch_sizes) == 0:
                logger.warning(
                    "torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
                )
                self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
        if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
            self.torchair_graph_batch_sizes.append(self.max_num_reqs)

    # padded max number tokens = max_num_req * decode_token_per_req
    self.torchair_graph_batch_sizes = [
        graph_batch_size * self.decode_token_per_req
        for graph_batch_size in self.torchair_graph_batch_sizes
    ]

    # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
    tp_size = self.parallel_config.tensor_parallel_size
    if self.parallel_config.enable_expert_parallel:
        new_graph_batch_sizes = []
        for graph_batch_size in self.torchair_graph_batch_sizes:
            # Round up to the next multiple of tp_size.
            cur_graph_batch_size = (graph_batch_size + tp_size -
                                    1) // tp_size * tp_size
            if cur_graph_batch_size not in new_graph_batch_sizes and \
                cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
                new_graph_batch_sizes.append(cur_graph_batch_size)
            elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
                    and self.decode_token_per_req > 1:
                # BUGFIX: the original passed a second f-string as a %-format
                # argument to a message with no placeholders, so the second
                # half was never rendered and logging raised a formatting
                # error internally. Use a single lazy %-style message.
                logger.warning(
                    "torchair_graph_batch_sizes %s is bigger than "
                    "max_num_batched_tokens %s, will skip this batch size.",
                    cur_graph_batch_size,
                    self.scheduler_config.max_num_batched_tokens)
        self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
||||
|
||||
def get_supported_pooling_tasks(self):
    """Return the pooling tasks the loaded model supports.

    Non-pooling models yield an empty list.
    """
    loaded_model = self.get_model()
    if not is_pooling_model(loaded_model):
        return []
    return list(loaded_model.pooler.get_supported_tasks())
|
||||
|
||||
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||
return False
|
||||
|
||||
@@ -48,6 +48,8 @@ class MtpProposer:
|
||||
device=self.runner.device)
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.torchair_graph_enabled = get_ascend_config(
|
||||
).torchair_graph_config.enabled
|
||||
|
||||
@staticmethod
|
||||
def prepare_inputs(
|
||||
@@ -136,7 +138,7 @@ class MtpProposer:
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
if token_indices is not None and self.runner.torchair_graph_enabled:
|
||||
if token_indices is not None and self.torchair_graph_enabled:
|
||||
last_token_indices = token_indices
|
||||
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
@@ -154,7 +156,7 @@ class MtpProposer:
|
||||
# input_batch=self.runner.input_batch,
|
||||
# scheduler_output=self.runner.scheduler_output,
|
||||
# )
|
||||
is_running_torchair = self.runner.torchair_graph_enabled and \
|
||||
is_running_torchair = self.torchair_graph_enabled and \
|
||||
not self.runner.with_prefill
|
||||
|
||||
if is_running_torchair:
|
||||
@@ -193,7 +195,7 @@ class MtpProposer:
|
||||
attn_metadata.prefill.input_positions = target_positions
|
||||
attn_metadata.prefill.seq_lens = seq_lens
|
||||
|
||||
if not self.runner.torchair_graph_enabled:
|
||||
if not self.torchair_graph_enabled:
|
||||
# torch mode need to update num_tokens_across_dp
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
||||
@@ -216,7 +218,7 @@ class MtpProposer:
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
model_kwargs = {}
|
||||
model_kwargs["attn_metadata"] = attn_metadata
|
||||
if self.runner.torchair_graph_enabled:
|
||||
if self.torchair_graph_enabled:
|
||||
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
|
||||
if is_running_torchair:
|
||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
@@ -280,12 +282,12 @@ class MtpProposer:
|
||||
skip_attn: bool = False,
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp=None) -> None:
|
||||
if not self.runner.torchair_graph_enabled:
|
||||
if not self.torchair_graph_enabled:
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self.runner._get_forward_metadata_across_dp_and_pad(
|
||||
num_tokens, with_prefill, False)
|
||||
is_running_torchair = self.runner.torchair_graph_enabled and \
|
||||
is_running_torchair = self.torchair_graph_enabled and \
|
||||
not with_prefill
|
||||
|
||||
if is_running_torchair:
|
||||
|
||||
Reference in New Issue
Block a user