[5/N][Refactor] torchair model runner refactor (#2216)
There is lot of torchair code in model runner leading the code hard for
maintenance. We'll create new torchair_model_runner to split torchair
related logic. Following the workflow #2203
What's this PR do:
create common function `_capture_model` for capture_model
- vLLM version: v0.10.0
- vLLM main:
1891a265d3
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -23,7 +23,11 @@ import torch
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||||
|
write_kv_cache_bytes_to_file)
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||||
maybe_converting_weight_acl_format)
|
maybe_converting_weight_acl_format)
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
@@ -37,6 +41,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
def _get_forward_metadata_across_dp_and_pad(
|
def _get_forward_metadata_across_dp_and_pad(
|
||||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||||
|
"""Override from NPUModelRunner to pad num_tokens"""
|
||||||
if self.dp_size == 1:
|
if self.dp_size == 1:
|
||||||
if not with_prefill:
|
if not with_prefill:
|
||||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||||
@@ -118,3 +123,49 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
def _convert_torch_format(self, kv_cache):
|
def _convert_torch_format(self, kv_cache):
|
||||||
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
|
kv_cache = torch_npu.npu_format_cast(kv_cache, ACL_FORMAT_FRACTAL_ND)
|
||||||
return kv_cache
|
return kv_cache
|
||||||
|
|
||||||
|
def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
|
||||||
|
# Trigger torchair graph capture for specific shapes.
|
||||||
|
# Capture the large shapes first so that the smaller shapes
|
||||||
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
|
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
||||||
|
for _ in range(self.vllm_config.compilation_config.
|
||||||
|
cudagraph_num_of_warmups):
|
||||||
|
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||||
|
self._dummy_run(num_tokens, is_torchair_compile=True)
|
||||||
|
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||||
|
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
||||||
|
|
||||||
|
def _capture_model(self):
|
||||||
|
"""Override from NPUModelRunner to use torchair graph capture."""
|
||||||
|
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
|
||||||
|
# torchair graph capture can cause some issues, so now we just
|
||||||
|
# temporarily split the codepath for the two different graph patterns.
|
||||||
|
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
|
||||||
|
graph_num = len(torchair_graph_batch_sizes)
|
||||||
|
|
||||||
|
if self.use_cached_npu_graph and not check_torchair_cache_exist():
|
||||||
|
# If caching is enabled but does not exist, we will compile the model twice. The first
|
||||||
|
# time is used to generate the cache, and the second time is used to load the cache to
|
||||||
|
# skip the overhead caused by Dynamo guard mechanism.
|
||||||
|
logger.info(
|
||||||
|
"Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.",
|
||||||
|
0.5 * graph_num, 1.5 * graph_num)
|
||||||
|
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||||
|
NPUPlatform.synchronize()
|
||||||
|
torch._dynamo.reset()
|
||||||
|
self.torchair_compiled_models.clear()
|
||||||
|
if self.use_cached_npu_graph:
|
||||||
|
logger.info(
|
||||||
|
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
|
||||||
|
0.3 * graph_num, 0.5 * graph_num)
|
||||||
|
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
|
||||||
|
0.5 * graph_num, 1.5 * graph_num)
|
||||||
|
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
||||||
|
|
||||||
|
if self.new_kv_cache_bytes > 0:
|
||||||
|
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
||||||
|
self.new_kv_cache_bytes)
|
||||||
|
|||||||
@@ -82,8 +82,6 @@ from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
|||||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
|
||||||
write_kv_cache_bytes_to_file)
|
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||||
ProfileExecuteDuration, is_310p,
|
ProfileExecuteDuration, is_310p,
|
||||||
maybe_converting_weight_acl_format,
|
maybe_converting_weight_acl_format,
|
||||||
@@ -2323,54 +2321,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
return kv_cache_spec
|
return kv_cache_spec
|
||||||
|
|
||||||
def _compile_torchair_graph(self, torchair_graph_batch_sizes) -> None:
|
def _capture_model(self):
|
||||||
# Trigger torchair graph capture for specific shapes.
|
if not self.use_aclgraph:
|
||||||
# Capture the large shapes first so that the smaller shapes
|
logger.info("Skipping NPU graph capture for eager mode.")
|
||||||
# can reuse the memory pool allocated for the large shapes.
|
return
|
||||||
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
|
|
||||||
for _ in range(self.vllm_config.compilation_config.
|
|
||||||
cudagraph_num_of_warmups):
|
|
||||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
|
||||||
self._dummy_run(num_tokens, is_torchair_compile=True)
|
|
||||||
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
|
||||||
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
|
|
||||||
|
|
||||||
def capture_model(self) -> None:
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
start_free_npu_memory = torch.npu.mem_get_info()[0]
|
|
||||||
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
|
|
||||||
# torchair graph capture can cause some issues, so now we just
|
|
||||||
# temporarily split the codepath for the two different graph patterns.
|
|
||||||
if self.torchair_graph_enabled:
|
|
||||||
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
|
|
||||||
graph_num = len(torchair_graph_batch_sizes)
|
|
||||||
|
|
||||||
if self.use_cached_npu_graph and not check_torchair_cache_exist():
|
|
||||||
# If caching is enabled but does not exist, we will compile the model twice. The first
|
|
||||||
# time is used to generate the cache, and the second time is used to load the cache to
|
|
||||||
# skip the overhead caused by Dynamo guard mechanism.
|
|
||||||
logger.info(
|
|
||||||
"Use cached npu graph but cache doesn't exist! Now we compile graph to genetate torchair cache, this usually takes %.1f~%.1f mins.",
|
|
||||||
0.5 * graph_num, 1.5 * graph_num)
|
|
||||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
|
||||||
NPUPlatform.synchronize()
|
|
||||||
torch._dynamo.reset()
|
|
||||||
self.torchair_compiled_models.clear()
|
|
||||||
if self.use_cached_npu_graph:
|
|
||||||
logger.info(
|
|
||||||
"Loading torchair graph cache, this usually takes %.1f~%.1f mins.",
|
|
||||||
0.3 * graph_num, 0.5 * graph_num)
|
|
||||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
|
|
||||||
0.5 * graph_num, 1.5 * graph_num)
|
|
||||||
self._compile_torchair_graph(torchair_graph_batch_sizes)
|
|
||||||
|
|
||||||
if self.new_kv_cache_bytes > 0:
|
|
||||||
write_kv_cache_bytes_to_file(torch.distributed.get_rank(),
|
|
||||||
self.new_kv_cache_bytes)
|
|
||||||
elif self.use_aclgraph:
|
|
||||||
# Trigger ACL graph capture for specific shapes.
|
# Trigger ACL graph capture for specific shapes.
|
||||||
# Capture the large shapes first so that the smaller shapes
|
# Capture the large shapes first so that the smaller shapes
|
||||||
# can reuse the memory pool allocated for the large shapes.
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
@@ -2381,9 +2335,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
cudagraph_num_of_warmups):
|
cudagraph_num_of_warmups):
|
||||||
self._dummy_run(num_tokens)
|
self._dummy_run(num_tokens)
|
||||||
self._dummy_run(num_tokens)
|
self._dummy_run(num_tokens)
|
||||||
else:
|
|
||||||
logger.info("Skipping NPU graph capture for eager mode.")
|
def capture_model(self) -> None:
|
||||||
return
|
start_time = time.perf_counter()
|
||||||
|
start_free_npu_memory = torch.npu.mem_get_info()[0]
|
||||||
|
|
||||||
|
self._capture_model()
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
end_free_npu_memory = torch.npu.mem_get_info()[0]
|
end_free_npu_memory = torch.npu.mem_get_info()[0]
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
|
|||||||
Reference in New Issue
Block a user