[5/N][Refactor] torchair model runner refactor (#2216)
There is a lot of torchair code in the model runner, which makes the code hard to maintain. We'll create a new torchair_model_runner to split out the torchair-related logic, following the workflow in #2203.
What this PR does:
- create a common function `_capture_model` for capture_model, so the torchair runner only has to override the capture path (a sketch of the resulting structure follows below)
- vLLM version: v0.10.0
- vLLM main: 1891a265d3
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
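To make the split easier to follow, here is a minimal, self-contained sketch of the resulting structure: the base runner keeps the timing shell in `capture_model` and delegates the actual graph capture to a `_capture_model` hook, which the torchair runner overrides. The class and method names follow the diff below; the stub bodies and the print statements are illustrative only.

```python
import time


class NPUModelRunner:
    """Base runner: capture_model() is the common shell, _capture_model() is the hook."""

    def capture_model(self) -> None:
        start_time = time.perf_counter()
        self._capture_model()  # dispatches to the subclass override if one exists
        elapsed_time = time.perf_counter() - start_time
        print(f"graph capture took {elapsed_time:.2f}s")

    def _capture_model(self) -> None:
        # Default path: ACL graph capture (or skip in eager mode) -- stubbed here.
        print("capturing ACL graphs")


class NPUTorchairModelRunner(NPUModelRunner):
    """Torchair runner: only the hook changes, the shell in capture_model is reused."""

    def _capture_model(self) -> None:
        # Torchair-specific path: compile one graph per padded batch size -- stubbed here.
        print("compiling torchair graphs")


if __name__ == "__main__":
    NPUTorchairModelRunner().capture_model()  # torchair message first, then the timing
```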
@@ -23,7 +23,11 @@ import torch
import torch_npu
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.logger import logger

from vllm_ascend.platform import NPUPlatform
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
                                        write_kv_cache_bytes_to_file)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               maybe_converting_weight_acl_format)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -37,6 +41,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
    def _get_forward_metadata_across_dp_and_pad(
            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
        """Override from NPUModelRunner to pad num_tokens"""
        if self.dp_size == 1:
            if not with_prefill:
                maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
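For context on the padding override above: torchair graphs are compiled for a fixed set of batch sizes, so the decode token count has to be rounded up to one of them. The internals of `select_torchair_padded_batch_size` are not shown in this diff; the helper below is only a hedged illustration of the usual "round up to the nearest captured size" behaviour, not the actual vllm-ascend implementation.

```python
def select_padded_batch_size(num_tokens: int, graph_batch_sizes: list[int]) -> int:
    """Illustrative only: round num_tokens up to the smallest captured batch size."""
    for size in sorted(graph_batch_sizes):
        if num_tokens <= size:
            return size
    # Fall back to the largest captured size if num_tokens exceeds all of them.
    return max(graph_batch_sizes)


assert select_padded_batch_size(13, [8, 16, 32]) == 16
assert select_padded_batch_size(40, [8, 16, 32]) == 32
```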
@@ -118,3 +123,49 @@ class NPUTorchairModelRunner(NPUModelRunner):
    def _convert_torch_format(self, kv_cache):
        kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
        return kv_cache

    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
        # Trigger torchair graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
            for _ in range(self.vllm_config.compilation_config.
                           cudagraph_num_of_warmups):
                self._dummy_run(num_tokens, is_torchair_compile=True)
            self._dummy_run(num_tokens, is_torchair_compile=True)
            logger.info("Batchsize %d is compiled successfully: %d/%d.",
                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))

    def _capture_model(self):
        """Override from NPUModelRunner to use torchair graph capture."""
        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
        # torchair graph capture can cause some issues, so now we just
        # temporarily split the codepath for the two different graph patterns.
        torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
        graph_num = len(torchair_graph_batch_sizes)

        if self.use_cached_npu_graph and not check_torchair_cache_exist():
            # If caching is enabled but the cache does not exist, we will compile the model
            # twice. The first time is used to generate the cache, and the second time is
            # used to load the cache and skip the overhead caused by the Dynamo guard mechanism.
            logger.info(
                "Use cached npu graph but cache doesn't exist! Now we compile graph to generate torchair cache, this usually takes %.1f~%.1f mins.",
                0.5 * graph_num, 1.5 * graph_num)
            self._compile_torchair_graph(torchair_graph_batch_sizes)
            NPUPlatform.synchronize()
            torch._dynamo.reset()
            self.torchair_compiled_models.clear()
        if self.use_cached_npu_graph:
            logger.info(
                "Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
                0.3 * graph_num, 0.5 * graph_num)
            self._compile_torchair_graph(torchair_graph_batch_sizes)
        else:
            logger.info(
                "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
                0.5 * graph_num, 1.5 * graph_num)
            self._compile_torchair_graph(torchair_graph_batch_sizes)

        if self.new_kv_cache_bytes > 0:
            write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
                                         self.new_kv_cache_bytes)
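For readers who want to exercise the cached-graph branch above: in vllm-ascend the torchair graph mode is normally switched on through `additional_config` when constructing the engine. The key names below are assumptions based on the torchair graph config and should be verified against the vllm-ascend documentation for your release; the model name is just a placeholder.

```python
from vllm import LLM

# Hedged example: config keys may differ across vllm-ascend versions.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder model name
    additional_config={
        "torchair_graph_config": {
            "enabled": True,            # take the torchair _capture_model path
            "use_cached_graph": True,   # triggers the compile-twice / cache-load flow above
            "graph_batch_sizes": [16, 32, 64],  # shapes handled by _compile_torchair_graph
        },
    },
)
```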
@@ -82,8 +82,6 @@ from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
-from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
-                                        write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
                                maybe_converting_weight_acl_format,
@@ -2323,67 +2321,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         return kv_cache_spec
 
-    def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
-        # Trigger torchair graph capture for specific shapes.
-        # Capture the large shapes first so that the smaller shapes
-        # can reuse the memory pool allocated for the large shapes.
-        for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
-            for _ in range(self.vllm_config.compilation_config.
-                           cudagraph_num_of_warmups):
-                self._dummy_run(num_tokens, is_torchair_compile=True)
-            self._dummy_run(num_tokens, is_torchair_compile=True)
-            logger.info("Batchsize %d is compiled successfully: %d/%d.",
-                        num_tokens, idx + 1, len(torchair_graph_batch_sizes))
+    def _capture_model(self):
+        if not self.use_aclgraph:
+            logger.info("Skipping NPU graph capture for eager mode.")
+            return
+        # Trigger ACL graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
+        with graph_capture(device=self.device):
+            for num_tokens in reversed(self.aclgraph_batch_sizes):
+                for _ in range(self.vllm_config.compilation_config.
+                               cudagraph_num_of_warmups):
+                    self._dummy_run(num_tokens)
+                self._dummy_run(num_tokens)
 
     def capture_model(self) -> None:
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
-        # torchair graph capture can cause some issues, so now we just
-        # temporarily split the codepath for the two different graph patterns.
-        if self.torchair_graph_enabled:
-            torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
-            graph_num = len(torchair_graph_batch_sizes)
-
-            if self.use_cached_npu_graph and not check_torchair_cache_exist():
-                # If caching is enabled but the cache does not exist, we will compile the model
-                # twice. The first time is used to generate the cache, and the second time is
-                # used to load the cache and skip the overhead caused by the Dynamo guard mechanism.
-                logger.info(
-                    "Use cached npu graph but cache doesn't exist! Now we compile graph to generate torchair cache, this usually takes %.1f~%.1f mins.",
-                    0.5 * graph_num, 1.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
-                NPUPlatform.synchronize()
-                torch._dynamo.reset()
-                self.torchair_compiled_models.clear()
-            if self.use_cached_npu_graph:
-                logger.info(
-                    "Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
-                    0.3 * graph_num, 0.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
-            else:
-                logger.info(
-                    "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
-                    0.5 * graph_num, 1.5 * graph_num)
-                self._compile_torchair_graph(torchair_graph_batch_sizes)
-
-            if self.new_kv_cache_bytes > 0:
-                write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
-                                             self.new_kv_cache_bytes)
-        elif self.use_aclgraph:
-            # Trigger ACL graph capture for specific shapes.
-            # Capture the large shapes first so that the smaller shapes
-            # can reuse the memory pool allocated for the large shapes.
-            # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
-            with graph_capture(device=self.device):
-                for num_tokens in reversed(self.aclgraph_batch_sizes):
-                    for _ in range(self.vllm_config.compilation_config.
-                                   cudagraph_num_of_warmups):
-                        self._dummy_run(num_tokens)
-                    self._dummy_run(num_tokens)
-        else:
-            logger.info("Skipping NPU graph capture for eager mode.")
-            return
+
+        self._capture_model()
+
         end_time = time.perf_counter()
         end_free_npu_memory = torch.npu.mem_get_info()[0]
         elapsed_time = end_time - start_time
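One detail the refactor keeps on both paths is the capture order: shapes are walked from the largest batch size down, so the memory pool sized for the biggest graph can be reused by all smaller ones. Below is a minimal sketch of that loop shape only; `dummy_run`, the batch-size list, and the warmup count stand in for the runner's real attributes and methods.

```python
def capture_in_reuse_order(graph_batch_sizes: list[int], num_warmups: int) -> None:
    """Illustrative only: mirrors the largest-first loop used by both capture paths."""
    for idx, num_tokens in enumerate(reversed(sorted(graph_batch_sizes))):
        for _ in range(num_warmups):
            dummy_run(num_tokens)  # warm-up passes for this shape
        dummy_run(num_tokens)      # the pass that is actually captured/compiled
        print(f"shape {num_tokens} done ({idx + 1}/{len(graph_batch_sizes)})")


def dummy_run(num_tokens: int) -> None:
    # Stand-in for NPUModelRunner._dummy_run; a real run would execute the model
    # on a padded dummy batch of `num_tokens` tokens.
    pass


capture_in_reuse_order([16, 32, 64], num_warmups=1)  # processes 64, then 32, then 16
```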