diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 0d264c7..463a30d 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -261,6 +261,20 @@ class TestUtils(TestBase): self.assertEqual( 147, len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) + + test_vllm_config.speculative_config = mock.MagicMock() + test_vllm_config.speculative_config.draft_model_config = mock.MagicMock( + ) + test_vllm_config.speculative_config.draft_model_config.hf_config = mock.MagicMock( + ) + test_vllm_config.speculative_config.draft_model_config.hf_config.num_hidden_layers = 2 + os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV' + utils.update_aclgraph_sizes(test_vllm_config) + del os.environ['HCCL_OP_EXPANSION_MODE'] + self.assertEqual( + 120, + len(test_vllm_config.compilation_config.cudagraph_capture_sizes)) + # max_num_batch_sizes >= len(original_sizes) test_compilation_config = CompilationConfig( cudagraph_capture_sizes=[1, 2, 3]) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index adab490..f3c1aef 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -304,6 +304,12 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: num_hidden_layers = get_max_hidden_layers(hf_config) parallel_config = vllm_config.parallel_config + # Calculate maximum supported batch sizes considering model architecture + resources_per_graph = num_hidden_layers + 1 + if vllm_config.speculative_config is not None: + draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config + resources_per_graph += draft_model_hf_config.num_hidden_layers + 1 + # TODO: Find out whether we need to take into account the pp_size num_comm_groups = sum(size > 1 for size in [ parallel_config.data_parallel_size, @@ -318,8 +324,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: # Assume the following case: # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, # According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19 - max_num_batch_sizes = math.floor( - MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor) + max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE / + resources_per_graph / parallel_factor) logger.info( "Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes) @@ -335,8 +341,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: # MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4, # According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12 max_num_batch_sizes = math.floor( - (MAX_CAPTURE_SIZE - num_comm_groups * 40) / - (num_hidden_layers + 1) / (1 + num_comm_groups * 2)) + (MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph / + (1 + num_comm_groups * 2)) logger.info( "Calculated maximum supported batch sizes for ACL graph: %s", max_num_batch_sizes)