refact runner model v1 (#2461)
refact model runner v1
### What this PR does / why we need it?
1. Separate the execute model logic from the prepare input logic
2. Disassemble the torchchair in model runner v1
- vLLM version: v0.10.0
- vLLM main:
68fcd3fa73
---------
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -17,21 +17,29 @@
|
||||
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
|
||||
#
|
||||
|
||||
import types
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
import vllm.envs as envs_vllm
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
check_torchair_cache_exist,
|
||||
register_torchair_model,
|
||||
write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
maybe_converting_weight_acl_format)
|
||||
is_310p, maybe_converting_weight_acl_format)
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
@@ -39,6 +47,24 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
super().__init__(vllm_config, device)
|
||||
ascend_config = get_ascend_config()
|
||||
self.new_kv_cache_bytes = -1
|
||||
self.torchair_compiled_model = None # type: ignore
|
||||
self.torchair_compiled_models = {} # type: ignore
|
||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
self.check_torchair_graph_batch_sizes()
|
||||
|
||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||
self.torchair_graph_batch_sizes)
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
self._check_batch_sizes_consistency()
|
||||
register_torchair_model()
|
||||
|
||||
def _get_forward_metadata_across_dp_and_pad(
|
||||
@@ -180,3 +206,215 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
if self.new_kv_cache_bytes > 0:
|
||||
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
||||
self.new_kv_cache_bytes)
|
||||
|
||||
def _use_aclgraph(self) -> bool:
|
||||
return False
|
||||
|
||||
def _check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
|
||||
local = torch.tensor(self.torchair_graph_batch_sizes,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
gathered_graph_batch_size = local.clone()
|
||||
dist.all_reduce(gathered_graph_batch_size,
|
||||
group=get_dp_group().cpu_group)
|
||||
expected = local * self.dp_size
|
||||
|
||||
if not torch.equal(gathered_graph_batch_size, expected):
|
||||
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
||||
as_tuple=False).flatten().tolist()
|
||||
raise AssertionError(
|
||||
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
||||
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
||||
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
||||
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
||||
)
|
||||
|
||||
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
|
||||
if not with_prefill:
|
||||
self.graph_pad_size = graph_pad_size
|
||||
else:
|
||||
super()._update_graph_pad_size(with_prefill, graph_pad_size)
|
||||
|
||||
def _update_input_ids_and_positions(self, input_ids, positions,
|
||||
num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp):
|
||||
"""Override from NPUModelRunner to update input_ids and positions"""
|
||||
input_ids, positions = super()._update_input_ids_and_positions(
|
||||
input_ids, positions, num_input_tokens, with_prefill,
|
||||
padded_num_tokens_across_dp)
|
||||
|
||||
if not with_prefill:
|
||||
input_ids = self.input_ids[:padded_num_tokens_across_dp]
|
||||
positions = self.positions[:padded_num_tokens_across_dp]
|
||||
return input_ids, positions
|
||||
|
||||
def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
|
||||
padded_num_tokens_across_dp,
|
||||
input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds):
|
||||
model_kwargs = {
|
||||
"kv_caches": self.kv_caches,
|
||||
"attn_metadata": attn_metadata
|
||||
}
|
||||
if not with_prefill:
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
compiled_model = self._get_torchair_lazy_compiled_model(
|
||||
padded_num_tokens_across_dp)
|
||||
hidden_states = compiled_model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
assert self.model is not None
|
||||
maybe_converting_weight_acl_format(self.model,
|
||||
ACL_FORMAT_FRACTAL_ND)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
**model_kwargs,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
||||
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
|
||||
raise ValueError(
|
||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
|
||||
)
|
||||
|
||||
compiled_model = self.torchair_compiled_models.get(
|
||||
batch_size
|
||||
) if self.use_cached_npu_graph else self.torchair_compiled_model
|
||||
|
||||
if compiled_model:
|
||||
return compiled_model
|
||||
|
||||
import torchair # type: ignore
|
||||
from torchair import patch_for_hcom # type: ignore
|
||||
|
||||
patch_for_hcom()
|
||||
|
||||
if is_310p():
|
||||
# on 300I Duo platform, we need to patch broadcast. however, this patch will be
|
||||
# overwritten by patch_for_hcom in torchair. so we need to re-patch it here.
|
||||
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
|
||||
communication_adaptation_310p
|
||||
communication_adaptation_310p()
|
||||
|
||||
config = torchair.CompilerConfig()
|
||||
config.experimental_config.frozen_parameter = True
|
||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||
# disable it on 300I Duo platform now.
|
||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||
config.experimental_config.enable_view_optimize = \
|
||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
||||
torch.npu.set_compile_mode(jit_compile=False)
|
||||
if not self.use_cached_npu_graph:
|
||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||
self.torchair_compiled_model = torch.compile(
|
||||
self.model,
|
||||
dynamic=True,
|
||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
backend=npu_backend)
|
||||
return self.torchair_compiled_model
|
||||
else:
|
||||
# Generate a new forward proxy code object to prevent the invalidation of
|
||||
# compilation cache caused by dynamo retracing
|
||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
||||
forward_fn = self.model.forward
|
||||
code = forward_fn.__code__
|
||||
# Mark code object with a new proxy name
|
||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
||||
|
||||
modified_func = types.FunctionType(modified_code,
|
||||
forward_fn.__globals__,
|
||||
name=forward_proxy_name,
|
||||
argdefs=forward_fn.__defaults__)
|
||||
|
||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
||||
self.model, nn.Module)
|
||||
self.torchair_compiled_models[
|
||||
batch_size] = torchair.inference.cache_compile(
|
||||
self.model.__dict__[forward_proxy_name],
|
||||
dynamic=True,
|
||||
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||
config=config,
|
||||
ge_cache=False)
|
||||
return self.torchair_compiled_models[batch_size]
|
||||
|
||||
def init_torchair_graph_batch_sizes(self):
|
||||
start_graph_batch_size = 4
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
||||
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
||||
|
||||
while (start_graph_batch_size <= self.max_num_reqs):
|
||||
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
||||
start_graph_batch_size *= 2
|
||||
|
||||
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||
for padded_batch_size in self.torchair_graph_batch_sizes:
|
||||
if batch_size <= padded_batch_size:
|
||||
# we treat batch_size as num of requests
|
||||
return padded_batch_size
|
||||
raise ValueError(
|
||||
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
|
||||
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
|
||||
)
|
||||
|
||||
def check_torchair_graph_batch_sizes(self):
|
||||
# return graph_batch_sizes according to the max number of tokens
|
||||
# first pad according to the number of requests
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
else:
|
||||
self.torchair_graph_batch_sizes = sorted(
|
||||
self.torchair_graph_batch_sizes)
|
||||
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.pop()
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
logger.warning(
|
||||
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
|
||||
)
|
||||
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
|
||||
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
|
||||
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
|
||||
|
||||
# padded max number tokens = max_num_req * decode_token_per_req
|
||||
self.torchair_graph_batch_sizes = [
|
||||
graph_batch_size * self.decode_token_per_req
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes
|
||||
]
|
||||
|
||||
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
if self.parallel_config.enable_expert_parallel:
|
||||
new_graph_batch_sizes = []
|
||||
for graph_batch_size in self.torchair_graph_batch_sizes:
|
||||
cur_graph_batch_size = (graph_batch_size + tp_size -
|
||||
1) // tp_size * tp_size
|
||||
if cur_graph_batch_size not in new_graph_batch_sizes and \
|
||||
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
|
||||
new_graph_batch_sizes.append(cur_graph_batch_size)
|
||||
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
|
||||
and self.decode_token_per_req > 1:
|
||||
logger.warning(
|
||||
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
|
||||
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
|
||||
)
|
||||
self.torchair_graph_batch_sizes = new_graph_batch_sizes
|
||||
|
||||
def _build_drafter_prepare_inputs_torchair_param(self):
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user