[BugFix] Fix ACLgraph bug in Qwen3_32b_int8 case (#3204)
### What this PR does / why we need it? 1. Solved the issue where sizes capture failed for the Qwen3-32b-int8 model when aclgraph, dp1, and tp4 were enabled. 2. Added the exception thrown when sizes capture fails and provided a solution 3. Add this common problem to the FAQ doc ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ut - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
@@ -196,3 +196,18 @@ export ATB_LLM_LCOC_ENABLE=0
|
|||||||
### 19. How to fix the error "ImportError: Please install vllm[audio] for audio support" for Qwen2.5-Omni model?
|
### 19. How to fix the error "ImportError: Please install vllm[audio] for audio support" for Qwen2.5-Omni model?
|
||||||
The `Qwen2.5-Omni` model requires the `librosa` package to be installed, you need to install the `qwen-omni-utils` package to ensure all dependencies are met `pip install qwen-omni-utils`,
|
The `Qwen2.5-Omni` model requires the `librosa` package to be installed, you need to install the `qwen-omni-utils` package to ensure all dependencies are met `pip install qwen-omni-utils`,
|
||||||
this package will install `librosa` and its related dependencies, resolving the `ImportError: No module named 'librosa'` issue and ensuring audio processing functionality works correctly.
|
this package will install `librosa` and its related dependencies, resolving the `ImportError: No module named 'librosa'` issue and ensuring audio processing functionality works correctly.
|
||||||
|
|
||||||
|
### 20. How to troubleshoot and resolve size capture failures resulting from stream resource exhaustion, and what are the underlying causes?
|
||||||
|
|
||||||
|
```
|
||||||
|
error example in detail:
|
||||||
|
ERROR 09-26 10:48:07 [model_runner_v1.py:3029] ACLgraph sizes capture fail: RuntimeError:
|
||||||
|
ERROR 09-26 10:48:07 [model_runner_v1.py:3029] ACLgraph has insufficient available streams to capture the configured number of sizes.Please verify both the availability of adequate streams and the appropriateness of the configured size count.
|
||||||
|
```
|
||||||
|
|
||||||
|
Recommended mitigation strategies:
|
||||||
|
1. Manually configure the compilation_config parameter with a reduced size set: '{"cudagraph_capture_sizes":[size1, size2, size3, ...]}'.
|
||||||
|
2. Employ ACLgraph's full graph mode as an alternative to the piece-wise approach.
|
||||||
|
|
||||||
|
Root cause analysis:
|
||||||
|
The current stream requirement calculation for size captures only accounts for measurable factors including: data parallel size, tensor parallel size, expert parallel configuration, piece graph count, multistream overlap shared expert settings, and HCCL communication mode (AIV/AICPU). However, numerous unquantifiable elements - such as operator characteristics and specific hardware features - consume additional streams outside of this calculation framework, resulting in stream resource exhaustion during size capture operations.
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ class TestUtils(TestBase):
|
|||||||
utils.update_aclgraph_sizes(test_vllm_config)
|
utils.update_aclgraph_sizes(test_vllm_config)
|
||||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
138,
|
137,
|
||||||
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
|
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||||
|
|
||||||
test_vllm_config.speculative_config = mock.MagicMock()
|
test_vllm_config.speculative_config = mock.MagicMock()
|
||||||
@@ -273,7 +273,7 @@ class TestUtils(TestBase):
|
|||||||
utils.update_aclgraph_sizes(test_vllm_config)
|
utils.update_aclgraph_sizes(test_vllm_config)
|
||||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
112,
|
111,
|
||||||
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
|
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||||
|
|
||||||
# max_num_batch_sizes >= len(original_sizes)
|
# max_num_batch_sizes >= len(original_sizes)
|
||||||
|
|||||||
@@ -40,14 +40,6 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
VllmConfig = None
|
VllmConfig = None
|
||||||
|
|
||||||
# NOTE: Currently, we can only capture 1800 graphs at most,
|
|
||||||
# due to the limitation of ACL graph. This number is bounded by
|
|
||||||
# the number of streams, which is 2048, we save 248 streams
|
|
||||||
# as a buffer.
|
|
||||||
# Maximum number of graphs that can be captured by ACL Graph
|
|
||||||
# TODO: Find out whether we need to solve allreduce function
|
|
||||||
MAX_CAPTURE_SIZE = 1800
|
|
||||||
|
|
||||||
ASCEND_QUANTIZATION_METHOD = "ascend"
|
ASCEND_QUANTIZATION_METHOD = "ascend"
|
||||||
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
|
SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"]
|
||||||
REGISTERED_ASCEND_OPS = {}
|
REGISTERED_ASCEND_OPS = {}
|
||||||
@@ -293,6 +285,14 @@ def get_max_hidden_layers(hf_config) -> int:
|
|||||||
|
|
||||||
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
||||||
"""Update ACL graph capture sizes based on hardware limitations"""
|
"""Update ACL graph capture sizes based on hardware limitations"""
|
||||||
|
# NOTE: Currently, we can only capture 1800 graphs at most,
|
||||||
|
# due to the limitation of ACL graph. This number is bounded by
|
||||||
|
# the number of streams, which is 2048, we save 248 streams
|
||||||
|
# as a buffer.
|
||||||
|
# Maximum number of graphs that can be captured by ACL Graph
|
||||||
|
# TODO: Find out whether we need to solve allreduce function
|
||||||
|
MAX_CAPTURE_SIZE = 1800
|
||||||
|
|
||||||
# Store original configuration and temporarily clear it
|
# Store original configuration and temporarily clear it
|
||||||
compilation_config = vllm_config.compilation_config
|
compilation_config = vllm_config.compilation_config
|
||||||
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
original_sizes, compilation_config.cudagraph_capture_sizes = \
|
||||||
@@ -326,6 +326,11 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
|
|||||||
"multistream_overlap_shared_expert", False))
|
"multistream_overlap_shared_expert", False))
|
||||||
if is_moe_model(vllm_config):
|
if is_moe_model(vllm_config):
|
||||||
parallel_factor += (parallel_config.data_parallel_size > 1)
|
parallel_factor += (parallel_config.data_parallel_size > 1)
|
||||||
|
else:
|
||||||
|
# When AIV mode is enabled, the allreduce operator of the dense
|
||||||
|
# layer model will occupy additional streams, which are buffered here.
|
||||||
|
MAX_CAPTURE_SIZE = MAX_CAPTURE_SIZE - parallel_factor * resources_per_graph
|
||||||
|
|
||||||
# Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device
|
# Calculate maximum supported batch sizes considering model architecture on the A2 Hardware Device
|
||||||
# Assume the following case:
|
# Assume the following case:
|
||||||
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
|
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
|
||||||
|
|||||||
@@ -3418,10 +3418,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
|
aclgraph_runtime_mode = aclgraph_mode.mixed_mode()
|
||||||
|
|
||||||
compilation_cases = list(reversed(self.aclgraph_batch_sizes))
|
compilation_cases = list(reversed(self.aclgraph_batch_sizes))
|
||||||
self._capture_aclgraphs(
|
|
||||||
compilation_cases,
|
try:
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
self._capture_aclgraphs(
|
||||||
uniform_decode=False)
|
compilation_cases,
|
||||||
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
|
uniform_decode=False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"ACLgraph sizes capture fail: {type(e).__name__}:\n"
|
||||||
|
"ACLgraph has insufficient available streams to capture the configured number of sizes. "
|
||||||
|
"Please verify both the availability of adequate streams and the appropriateness of the configured size count.\n\n"
|
||||||
|
"Recommended solutions:\n"
|
||||||
|
"1. Manually configure the compilation_config parameter "
|
||||||
|
"with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
|
||||||
|
"2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n"
|
||||||
|
f"{str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
|
||||||
aclgraph_mode.separate_routine():
|
aclgraph_mode.separate_routine():
|
||||||
|
|||||||
Reference in New Issue
Block a user