[aclgraph] implement NPUPiecewiseBackend to enable aclgraph (#836)
### What this PR does / why we need it?
1. Implement `NPUPiecewiseBackend` to enable aclgraph.
2. Enable aclgraph by default in V1, but raise an error when running DeepSeek and a warning when running models other than Qwen.

### How was this patch tested?
CI passes with the new UT.

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
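With this change, aclgraph is on by default for V1, so a user gets graph execution without any extra flag and opts back into eager mode with `enforce_eager=True`. A minimal sketch of the two modes (model name is taken from the new UT; the prompt is illustrative):

```python
from vllm import LLM, SamplingParams

# V1 engine: aclgraph is captured and replayed by default
llm = LLM("Qwen/Qwen2.5-0.5B-Instruct")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32, temperature=0.0))

# opt out of aclgraph and run eagerly (free the first engine beforehand
# in a real session to avoid holding NPU memory twice)
eager_llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enforce_eager=True)
```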
102
tests/compile/test_aclgraph.py
Normal file
@@ -0,0 +1,102 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Compare the outputs of vLLM with and without aclgraph.

Run `pytest tests/compile/test_aclgraph.py`.
"""

import os

import pytest
import torch
from vllm import LLM, SamplingParams

from tests.conftest import VllmRunner
from tests.model_utils import check_outputs_equal
from vllm_ascend.utils import vllm_version_is

MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="aclgraph only support on v1")
@pytest.mark.skipif(
    (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
    reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models(
    model: str,
    max_tokens: int,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    with monkeypatch.context() as m:
        prompts = [
            "Hello, my name is", "The president of the United States is",
            "The capital of France is", "The future of AI is"
        ]

        # aclgraph only support on v1
        m.setenv("VLLM_USE_V1", "1")

        sampling_params = SamplingParams(max_tokens=max_tokens,
                                         temperature=0.0)
        # TODO: change to use vllmrunner when the registry of custom op is solved
        # while running pytest
        vllm_model = LLM(model)
        vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
        del vllm_model
        torch.npu.empty_cache()

        vllm_model = LLM(model, enforce_eager=True)
        vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
        del vllm_model
        torch.npu.empty_cache()

        vllm_aclgraph_outputs_list = []
        for output in vllm_aclgraph_outputs:
            vllm_aclgraph_outputs_list.append(
                (output.outputs[0].index, output.outputs[0].text))

        vllm_eager_outputs_list = []
        for output in vllm_eager_outputs:
            vllm_eager_outputs_list.append(
                (output.outputs[0].index, output.outputs[0].text))

        check_outputs_equal(
            outputs_0_lst=vllm_eager_outputs_list,
            outputs_1_lst=vllm_aclgraph_outputs_list,
            name_0="vllm_eager_outputs",
            name_1="vllm_aclgraph_outputs",
        )


@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
                    reason="aclgraph only support on v1")
@pytest.mark.skipif(
    (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
    reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
def test_deepseek_raises_error(monkeypatch: pytest.MonkeyPatch) -> None:
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_MODELSCOPE", "True")
        m.setenv("VLLM_USE_V1", "1")
        with pytest.raises(NotImplementedError) as excinfo:
            VllmRunner("deepseek-ai/DeepSeek-V2-Lite-Chat",
                       max_model_len=1024,
                       enforce_eager=False)
        assert "ACL Graph does not support deepseek" in str(excinfo.value)
@@ -77,7 +77,7 @@ class VllmRunner:
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
         swap_space: int = 4,
-        enforce_eager: Optional[bool] = False,
+        enforce_eager: Optional[bool] = True,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -72,7 +72,7 @@ def test_ngram_correctness(
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

-        ref_llm = LLM(model=model_name, max_model_len=1024)
+        ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
@@ -85,6 +85,7 @@ def test_ngram_correctness(
                 "num_speculative_tokens": 3,
             },
             max_model_len=1024,
+            enforce_eager=True,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0
@@ -135,6 +136,7 @@ def test_eagle_correctness(
                 "max_model_len": 2048,
             },
             max_model_len=2048,
+            enforce_eager=True,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0
@@ -18,8 +18,7 @@ import pytest
 import torch
 from vllm import LLM, SamplingParams

-# TODO: revert me when cuda hard code is fixed in 'VllmBackend'
-torch.cuda.CUDAGraph = torch.npu.NPUGraph
+from vllm_ascend.utils import vllm_version_is

 MODELS = [
     "Qwen/Qwen2.5-0.5B-Instruct",
@@ -33,6 +32,9 @@ prompts = [
 ]


+@pytest.mark.skipif(
+    (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")),
+    reason="aclgraph not supported in v0.8.5 and v0.8.5.post1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
 @pytest.mark.parametrize("max_tokens", [64])
@@ -52,7 +52,7 @@ def test_models(model: str, dtype: str, max_tokens: int) -> None:
     with VllmRunner(model,
                     max_model_len=8192,
                     dtype=dtype,
-                    enforce_eager=False,
+                    enforce_eager=True,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
226
vllm_ascend/compilation/piecewise_backend.py
Normal file
@@ -0,0 +1,226 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/compilation/cuda_piecewise_backend.py
#

import dataclasses
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set
from unittest.mock import patch

import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import logger
from vllm.utils import weak_ref_tensors


@dataclasses.dataclass
class ConcreteSizeEntry:
    runtime_shape: int
    need_to_compile: bool  # the size is in compile_sizes
    use_aclgraph: bool  # the size is in cudagraph_capture_sizes

    compiled: bool = False
    runnable: Callable = None  # type: ignore
    num_finished_warmup: int = 0
    aclgraph: Optional[torch.npu.NPUGraph] = None
    output: Optional[Any] = None

    # for aclgraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: Optional[List[int]] = None


class NPUPiecewiseBackend:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 graph_pool: Any, piecewise_compile_index: int,
                 total_piecewise_compiles: int, sym_shape_indices: List[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation and aclgraph capturing.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Independently, we will capture aclgraph for different shapes.

        If a shape needs both compilation and aclgraph, we will
        compile it first, and then capture aclgraph.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.graph_pool = graph_pool
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.compile_sizes: Set[int] = set(
            self.compilation_config.compile_sizes)
        self.aclgraph_capture_sizes: Set[int] = set(
            self.compilation_config.cudagraph_capture_sizes
        ) if self.compilation_config.use_cudagraph else set()

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # the entries for different shapes that we need to either
        # compile or capture aclgraph
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy()
        for shape in self.compile_sizes.union(self.aclgraph_capture_sizes):
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_shape=shape,
                need_to_compile=shape in self.compile_sizes,
                use_aclgraph=shape in self.aclgraph_capture_sizes,
            )

    def check_for_ending_compilation(self):
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.need_to_compile and not entry.compiled:
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)
            # args are real arguments
            entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # finished compilations for all required shapes
            if self.is_last_graph and not self.to_be_compiled_sizes:
                self.check_for_ending_compilation()

        if not entry.use_aclgraph:
            return entry.runnable(*args)

        if entry.aclgraph is None:
            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                entry.num_finished_warmup += 1
                if self.is_first_graph:
                    logger.debug(
                        "Warming up %s/%s for shape %s",
                        entry.num_finished_warmup,
                        self.compilation_config.cudagraph_num_of_warmups,
                        runtime_shape)
                return entry.runnable(*args)

            if self.is_first_graph:
                # Since we capture aclgraph for many different shapes and
                # capturing is fast, we don't need to log it for every shape.
                # We only log it in the debug mode.
                logger.debug("Capturing an aclgraph for shape %s",
                             runtime_shape)

            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            aclgraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of aclgraphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the aclgraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.npu.empty_cache", lambda: None))

                # mind-exploding: carefully manage the reference and memory.
                with torch.npu.graph(aclgraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's aclgraph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other npu aclgraph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.aclgraph = aclgraph

            compilation_counter.num_cudagraph_caputured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during npu aclgraph capture
            return output

        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for aclgraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )

        entry.aclgraph.replay()
        return entry.output
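The core of `__call__` above is a per-shape state machine: unknown shapes fall through to the general compiled graph, known shapes go through warmup runs, then a one-time capture, then replay. A stripped-down, framework-free sketch of that flow (`_Entry`, `_FrozenGraph`, and `dispatch` are illustrative stand-ins, not the vllm API):

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


class _FrozenGraph:
    """Stand-in for torch.npu.NPUGraph in this illustration."""

    def replay(self) -> None:
        pass  # a real NPUGraph would re-launch the captured kernels


@dataclass
class _Entry:
    # hypothetical mini ConcreteSizeEntry
    runnable: Callable
    num_warmups_left: int = 1
    graph: Optional[_FrozenGraph] = None
    output: Any = None


def dispatch(entries: Dict[int, _Entry], shape: int, args: tuple,
             general: Callable) -> Any:
    """Mimics the shape-keyed dispatch in NPUPiecewiseBackend.__call__."""
    if shape not in entries:
        # shape we never compile or capture for: use the general graph
        return general(*args)
    entry = entries[shape]
    if entry.graph is None:
        if entry.num_warmups_left > 0:
            # warm-up runs execute eagerly before any capture
            entry.num_warmups_left -= 1
            return entry.runnable(*args)
        # "capture": run once, remember the graph and its output buffer
        entry.output = entry.runnable(*args)
        entry.graph = _FrozenGraph()
        return entry.output
    # steady state: replay the captured graph, return the cached output
    entry.graph.replay()
    return entry.output


entries = {4: _Entry(runnable=lambda *xs: sum(xs))}
for _ in range(3):  # warm-up, capture, replay
    print(dispatch(entries, 4, (1, 2, 3, 4), general=lambda *xs: sum(xs)))
```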
@@ -33,7 +33,6 @@ class dummyFusionOp:


 def register_dummy_fusion_op() -> None:
-    torch.cuda.CUDAGraph = torch.npu.NPUGraph
     torch.ops._C.rms_norm = dummyFusionOp(name="rms_norm")
     torch.ops._C.fused_add_rms_norm = dummyFusionOp(name="fused_add_rms_norm")
     torch.ops._C.static_scaled_fp8_quant = dummyFusionOp(
@@ -23,7 +23,6 @@ import torch
 import vllm.envs as envs
 from vllm.logger import logger
 from vllm.platforms import Platform, PlatformEnum
-from vllm.utils import supports_dynamo

 from vllm_ascend.utils import ASCEND_QUATIZATION_METHOD, update_aclgraph_sizes
@@ -119,24 +118,48 @@ class NPUPlatform(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         from vllm.config import CompilationLevel  # noqa: E402
         compilation_config = vllm_config.compilation_config
+        model_config = vllm_config.model_config

-        if vllm_config.model_config is None:
+        if model_config is None:
             logger.warning("Model config is missing. This may indicate "
                            "that we are running a test case")
             enforce_eager = False
         else:
-            enforce_eager = getattr(vllm_config.model_config, "enforce_eager",
-                                    False)
+            enforce_eager = getattr(model_config, "enforce_eager", False)

-        # TODO(Yizhou): Override the value of enforce_eager to True before
-        # the CANN and torch_npu support NPU compilation.
-        enforce_eager = True
-        logger.warning(
-            "NPU compilation support pending. Will be available in future CANN and "
-            "torch_npu releases. NPU graph mode is currently experimental and disabled "
-            "by default. You can just adopt additional_config={'enable_graph_mode': True} "
-            "to serve deepseek models with NPU graph mode on vllm-ascend with V0 engine. "
-        )
+        if vllm_config.additional_config is not None:
+            enable_graph_mode = vllm_config.additional_config.get(
+                "enable_graph_mode", False)
+            if enable_graph_mode:
+                if enforce_eager:
+                    raise RuntimeError(
+                        "Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
+                    )
+                elif envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
+                    logger.warning(
+                        "NPU graph mode is still experimental and not supported for V1 without mla currently, "
+                        "it has been disabled automatically.")
+                    vllm_config.additional_config["enable_graph_mode"] = False
+                if model_config:
+                    model_type = model_config.hf_config.model_type
+                    if "deepseek" not in model_type:
+                        raise NotImplementedError(
+                            "enable_graph_mode only works with deepseek model."
+                        )
+
+        elif envs.VLLM_USE_V1 and model_config is not None and not enforce_eager:
+            model_type = model_config.hf_config.model_type
+            if "deepseek" in model_type:
+                raise NotImplementedError(
+                    "ACL Graph does not support deepseek. Please "
+                    "adopt additional_config={'enable_graph_mode': True} "
+                    "to serve deepseek models with NPU graph mode on vllm-ascend with V1 engine."
+                    " Or set `enforce_eager=True` to use eager mode.")
+            elif "qwen" not in model_type:
+                logger.warning(
+                    "ACL Graph is currently experimental. Please "
+                    "raise an issue on https://github.com/vllm-project/vllm-ascend/issues"
+                    " if you encounter any Error")

         if enforce_eager or compilation_config.level == CompilationLevel.NO_COMPILATION:
             logger.info("Compilation disabled, using eager mode by default")
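Taken together, the new checks give three user-facing paths on Ascend: Qwen on V1 runs under aclgraph silently, other non-DeepSeek models run with an experimental warning, and DeepSeek must either go eager or use NPU graph mode via `additional_config`. A hedged sketch of each path (model names come from this PR; pick one per process):

```python
from vllm import LLM

# 1. V1 + Qwen: aclgraph by default, no extra configuration
llm = LLM("Qwen/Qwen2.5-0.5B-Instruct")

# 2. V1 + DeepSeek without eager mode raises NotImplementedError,
#    so fall back to eager execution explicitly:
llm = LLM("deepseek-ai/DeepSeek-V2-Lite-Chat", enforce_eager=True)

# 3. DeepSeek with NPU graph mode, per the error message above
#    (requires enforce_eager=False, which is the default here):
llm = LLM("deepseek-ai/DeepSeek-V2-Lite-Chat",
          additional_config={"enable_graph_mode": True})
```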
@@ -155,20 +178,6 @@ class NPUPlatform(Platform):
                 ["vllm.unified_ascend_attention_with_output"])
         update_aclgraph_sizes(vllm_config)

-        if vllm_config.additional_config is not None:
-            enable_graph_mode = vllm_config.additional_config.get(
-                "enable_graph_mode", False)
-            if enable_graph_mode and not supports_dynamo():
-                logger.warning(
-                    "enable_graph_mode is not supported because the version of torch is too low, forcing close enable_graph_mode"
-                )
-                vllm_config.additional_config["enable_graph_mode"] = False
-            if enable_graph_mode and envs.VLLM_USE_V1 and envs.VLLM_MLA_DISABLE:
-                logger.warning(
-                    "NPU graph mode is still experimental and not supported for V1 without mla currently, "
-                    "it has been disabled automatically.")
-                vllm_config.additional_config["enable_graph_mode"] = False
-
         parallel_config = vllm_config.parallel_config
         if parallel_config and parallel_config.worker_cls == "auto":
             if envs.VLLM_USE_V1:
@@ -244,3 +253,10 @@ class NPUPlatform(Platform):
         model configuration.
         """
         return True
+
+    @classmethod
+    def get_piecewise_backend_cls(cls) -> str:
+        """
+        Get piecewise backend class for piecewise graph.
+        """
+        return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend"  # noqa
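The hook returns a fully qualified class path rather than a class object, so vllm can import the Ascend backend lazily. A minimal sketch of how such a dotted path resolves to a class (`resolve_cls` is an illustrative helper in the spirit of `vllm.utils.resolve_obj_by_qualname`; it assumes vllm-ascend is importable):

```python
import importlib


def resolve_cls(qualname: str):
    # split "pkg.module.Class" into the module path and class name,
    # then import the module and fetch the attribute
    module_name, _, cls_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), cls_name)


backend_cls = resolve_cls(
    "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend")
```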