diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index 43f6aa37..a51c0c36 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -165,7 +165,7 @@ class TestEagleProposerLoadModel(TestBase):
         self.proposer.load_model(mock_model)
 
         mock_get_model.assert_called_once()
-        self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
+        self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
         self.assertIs(self.proposer.model.model.embed_tokens,
                       mock_model.model.embed_tokens)
 
@@ -196,7 +196,7 @@ class TestEagleProposerLoadModel(TestBase):
 
         self.assertIsNot(self.proposer.model.model.embed_tokens,
                          mock_model.model.embed_tokens)
-        self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
+        self.assertEqual(self.proposer.attn_layer_names, ["layer2"])
 
     @patch(
         "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
@@ -239,6 +239,8 @@ class TestEagleProposerDummyRun(TestBase):
         self.vllm_config.speculative_config.num_speculative_tokens = 4
         self.device = torch.device("cpu")
         self.runner = MagicMock()
+        self.runner.pcp_size = 1
+        self.runner.dcp_size = 1
         self.vllm_config.cache_config.block_size = 16
         self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
 
@@ -246,6 +248,7 @@ class TestEagleProposerDummyRun(TestBase):
         self.vllm_config.model_config.dtype = torch.float16
         self.vllm_config.model_config.max_model_len = 2048
         self.vllm_config.model_config.uses_mrope = False
+        self.vllm_config.model_config.use_mla = False
         self.vllm_config.speculative_config.speculative_token_tree = str([
             (i + 1) * (0, ) for i in range(4)
         ])
diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py
index 7c69c12c..324ae321 100644
--- a/tests/ut/spec_decode/test_mtp_proposer.py
+++ b/tests/ut/spec_decode/test_mtp_proposer.py
@@ -30,7 +30,7 @@ class TestMtpProposer:
         config.additional_config = None
         config.speculative_config = MagicMock(spec=SpeculativeConfig)
         config.speculative_config.num_speculative_tokens = 2
-        config.speculative_config.method = "deepseek_mtp"
+        config.speculative_config.method = "mtp"
         config.speculative_config.draft_model_config = MagicMock()
         config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
         config.speculative_config.speculative_token_tree = str([
@@ -98,9 +98,11 @@ class TestMtpProposer:
         mock_buffer_instance = MagicMock()
         mock_cpu_gpu_buffer.return_value = mock_buffer_instance
         runner._use_aclgraph.return_value = True
+        vllm_config.scheduler_config.async_scheduling = False
+        vllm_config.speculative_config.enforce_eager = False
         proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
 
-        assert proposer.use_aclgraph is True
+        assert proposer.use_cuda_graph is True
 
     @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
     @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 6695c118..aef66fa4 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -91,23 +91,7 @@ class EagleProposer(VllmEagleProposer):
         super().__init__(vllm_config, device, runner)
         self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
 
-        # there is synchronization between mtp steps when enabling aclgraph,
-        # disable aclgraph when use async scheduling to avoid the
-        # synchronization overhead.
-        # NOTE: we need to set aclgraph_runtime_mode to None in both dummy_run
-        # and _propose.
-        self.use_cuda_graph = (
-            self.vllm_config.compilation_config.mode
-            == CompilationMode.VLLM_COMPILE
-            and not self.vllm_config.model_config.enforce_eager
-            and not self.use_async_scheduling
-            and not self.vllm_config.speculative_config.enforce_eager)
-        self.cudagraph_batch_sizes = list(
-            sorted(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
-
-        self.pcp_size = self.runner.pcp_size
         self.decode_threshold = 1 + self.num_speculative_tokens
         self.query_start_loc = self.runner._make_buffer(
             self.runner.max_num_reqs + 1, dtype=torch.int32)
@@ -118,12 +102,11 @@ class EagleProposer(VllmEagleProposer):
 
         self.enable_shared_expert_dp = shared_expert_dp_enabled()
 
+        self.pcp_size = self.runner.pcp_size
         self.dcp_size = self.runner.dcp_size
         self.pcp_rank = self.runner.pcp_rank
         self.dcp_rank = self.runner.dcp_rank
 
-        self.use_aclgraph = self.runner._use_aclgraph()
-
         self.full_indices = range(
             self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
             self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
@@ -131,6 +114,10 @@ class EagleProposer(VllmEagleProposer):
         self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
                                   "index_topk")
 
+        self.use_cuda_graph = (self.runner._use_aclgraph()
+                               and not self.speculative_config.enforce_eager
+                               and not self.use_async_scheduling)
+
         # TODO: Remove it when the bug of fx-graph is solved
         self.maybe_eager_context: ContextManager[Any] = nullcontext()
         if not self.use_cuda_graph and enable_sp(vllm_config):
@@ -158,8 +145,7 @@ class EagleProposer(VllmEagleProposer):
         draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
         draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
         assert len(draft_attn_layer_names) == 1
-        self.attn_layer_name = list(draft_attn_layer_names)
-        self.attn_layer_names = self.attn_layer_name
+        self.attn_layer_names = list(draft_attn_layer_names)
 
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
@@ -273,7 +259,7 @@ class EagleProposer(VllmEagleProposer):
             attn_metadata_eagle = builder.build_for_graph_capture(
                 common_attn_metadata, AscendAttentionState.ChunkedPrefill)
             attn_metadata = {}
-            for layer_name in self.attn_layer_name:
+            for layer_name in self.attn_layer_names:
                 attn_metadata[layer_name] = attn_metadata_eagle
 
         model_input_ids = self.input_ids[:num_tokens]
@@ -292,30 +278,22 @@ class EagleProposer(VllmEagleProposer):
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 is_draft_model=True):
 
-            forward_context = get_forward_context()
-            if forward_context.sp_enabled:
-                model_previous_hidden_states = split_inputs_tp_to_sp(
-                    model_previous_hidden_states,
-                    model_previous_hidden_states)
+            model_previous_hidden_states, model_positions = self.maybe_pad_and_reduce(
+                model_previous_hidden_states, model_positions)
 
             self.model(
                 input_ids=model_input_ids,
                 positions=model_positions,
                 hidden_states=model_previous_hidden_states,
             )
 
+            forward_context = get_forward_context()
             if (forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
                     and not forward_context.capturing):
-                update_attn_params(
-                    self.update_stream,
-                    forward_context,
-                    num_tokens,
-                    self.vllm_config,
-                )
+                self._update_full_graph_params(forward_context, num_tokens)
 
-            if forward_context.sp_enabled:
-                model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                    model_previous_hidden_states, True)
+            model_previous_hidden_states, model_positions, _ = self.maybe_all_gather_and_unpad(
+                model_previous_hidden_states, model_positions)
 
             dummy_compute_logits(self.hidden_states)
 
@@ -362,7 +340,7 @@ class EagleProposer(VllmEagleProposer):
         self.input_ids[last_token_indices] = next_token_ids
 
         if self.use_cuda_graph and \
-            num_tokens <= self.cudagraph_batch_sizes[-1]:
+            num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
             num_input_tokens = num_tokens
@@ -386,7 +364,7 @@ class EagleProposer(VllmEagleProposer):
         # update global cos, sin
         update_cos_sin(self.positions[:num_input_tokens])
         per_layer_attn_metadata = {}
-        for layer_name in self.attn_layer_name:
+        for layer_name in self.attn_layer_names:
             per_layer_attn_metadata[layer_name] = attn_metadata
 
         with set_ascend_forward_context(
@@ -403,34 +381,27 @@ class EagleProposer(VllmEagleProposer):
             model_positions = self.positions[:num_input_tokens]
             model_hidden_states = self.hidden_states[:num_input_tokens]
 
-            forward_context = get_forward_context()
-            if forward_context.sp_enabled:
-                # split hidden states along sequence dimension
-                # positions should not be split?
-                model_hidden_states = split_inputs_tp_to_sp(
-                    model_hidden_states, model_hidden_states)
+            model_hidden_states, model_positions = self.maybe_pad_and_reduce(
+                model_hidden_states, model_positions)
 
-            last_hidden_states, hidden_states = self.model(
+            ret_hidden_states = self.model(
                 input_ids=model_input_ids,
                 positions=model_positions,
                 hidden_states=model_hidden_states,
             )
+            if self.method == "mtp":
+                last_hidden_states = ret_hidden_states
+                hidden_states = last_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
 
+            forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-                # TODO: support mla in future.
-                update_attn_params(
-                    self.update_stream,
-                    forward_context,
-                    num_input_tokens,
-                    self.vllm_config,
-                )
+                self._update_full_graph_params(forward_context,
+                                               num_input_tokens)
 
-            if forward_context.sp_enabled:
-                # merge hidden states along sequence dimension
-                last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                    last_hidden_states.contiguous(), True)
-                hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                    hidden_states.contiguous(), True)
+            last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
+                last_hidden_states, model_positions, hidden_states)
 
             sample_hidden_states = last_hidden_states[last_token_indices]
             logits = self.model.compute_logits(sample_hidden_states)
@@ -453,7 +424,7 @@ class EagleProposer(VllmEagleProposer):
         last_token_indices = self.arange[:batch_size]
 
         if self.use_cuda_graph and \
-            batch_size <= self.cudagraph_batch_sizes[-1]:
+            batch_size <= self.runner.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
         else:
             input_batch_size = batch_size
@@ -556,32 +527,27 @@ class EagleProposer(VllmEagleProposer):
             model_positions = self.positions[:input_batch_size]
             model_hidden_states = self.hidden_states[:input_batch_size]
 
-            forward_context = get_forward_context()
-            if forward_context.sp_enabled:
-                # split hidden states along sequence dimension
-                # positions should not be split?
-                model_hidden_states = split_inputs_tp_to_sp(
-                    model_hidden_states, model_hidden_states)
+            model_hidden_states, model_positions = self.maybe_pad_and_reduce(
+                model_hidden_states, model_positions)
 
-            last_hidden_states, hidden_states = self.model(
+            ret_hidden_states = self.model(
                 input_ids=model_input_ids,
                 positions=model_positions,
                 hidden_states=model_hidden_states,
             )
-            if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-                update_attn_params(
-                    self.update_stream,
-                    forward_context,
-                    input_batch_size,
-                    self.vllm_config,
-                )
+            if self.method == "mtp":
+                last_hidden_states = ret_hidden_states
+                hidden_states = last_hidden_states
+            else:
+                last_hidden_states, hidden_states = ret_hidden_states
 
-            if forward_context.sp_enabled:
-                # merge hidden states along sequence dimension
-                last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                    last_hidden_states.contiguous(), True)
-                hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                    hidden_states.contiguous(), True)
+            forward_context = get_forward_context()
+            if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+                self._update_full_graph_params(forward_context,
+                                               input_batch_size)
+
+            last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
+                last_hidden_states, model_positions, hidden_states)
 
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size])
@@ -948,3 +914,46 @@ class EagleProposer(VllmEagleProposer):
         else:
             update_attn_params(self.update_stream, forward_context,
                                num_tokens, self.vllm_config)
+
+    def maybe_pad_and_reduce(
+        self,
+        hidden_states: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.method == "mtp":
+            if self.enable_shared_expert_dp:
+                hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+                    hidden_states)
+                positions = positions.unsqueeze(-1)
+                positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
+                positions = positions.squeeze(-1)
+        else:
+            forward_context = get_forward_context()
+            if forward_context.sp_enabled:
+                hidden_states = split_inputs_tp_to_sp(
+                    hidden_states, hidden_states)
+        return hidden_states, positions
+
+    def maybe_all_gather_and_unpad(
+        self,
+        last_hidden_states: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
+        if self.method == "mtp":
+            if self.enable_shared_expert_dp:
+                last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    last_hidden_states.contiguous(), True)
+                positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    positions.contiguous(), True)
+                if hidden_states is not None:
+                    hidden_states = last_hidden_states
+        else:
+            forward_context = get_forward_context()
+            if forward_context.sp_enabled:
+                last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    last_hidden_states.contiguous(), True)
+                if hidden_states is not None:
+                    hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                        hidden_states.contiguous(), True)
+        return last_hidden_states, positions, hidden_states
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 2f4ac4ad..f2710050 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -89,7 +89,7 @@ class MtpProposer(EagleProposer):
             attn_metadata_mtp = builder.build_for_graph_capture(
                 common_attn_metadata, attn_state)
             attn_metadata = {}
-            for layer_name in self.attn_layer_name:
+            for layer_name in self.attn_layer_names:
                 attn_metadata[layer_name] = attn_metadata_mtp
         else:
             attn_metadata = None
@@ -112,12 +112,8 @@ class MtpProposer(EagleProposer):
                 batch_descriptor=batch_descriptor,
                 is_draft_model=True,
                 in_profile_run=is_profile):
-            if self.enable_shared_expert_dp:
-                positions = positions.unsqueeze(-1)
-                positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
-                positions = positions.squeeze(-1)
-                previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                    previous_hidden_states)
+            previous_hidden_states, positions = self.maybe_pad_and_reduce(
+                previous_hidden_states, positions)
             self.model(input_ids=input_ids,
                        positions=positions,
                        hidden_states=previous_hidden_states)
@@ -126,11 +122,8 @@ class MtpProposer(EagleProposer):
                 not forward_context.capturing and not self.use_sparse:
                 self._update_full_graph_params(forward_context, num_tokens)
 
-            if self.enable_shared_expert_dp:
-                positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                    positions, True)
-                previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                    previous_hidden_states, True)
+            previous_hidden_states, positions, _ = self.maybe_all_gather_and_unpad(
+                previous_hidden_states, positions)
             dummy_compute_logits(previous_hidden_states)
             if with_prefill:
                 break
@@ -249,11 +242,11 @@ class MtpProposer(EagleProposer):
         assert self.runner is not None
 
         # Note(qcs): We may need to refactor these check logics.
-        if self.runner.use_aclgraph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[
+        if self.use_cuda_graph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[
                 -1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 num_scheduled_tokens)
-        elif self.use_aclgraph and num_tokens <= self.runner.cudagraph_batch_sizes[
+        elif self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[
                 -1]:
             # Acl graph mode, add padding to the batch size
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -304,7 +297,7 @@ class MtpProposer(EagleProposer):
         attn_metadata_mtp = builder.build(0, common_attn_metadata,
                                           self.runner.get_model())
         attn_metadata = {}
-        for layer_name in self.attn_layer_name:
+        for layer_name in self.attn_layer_names:
             attn_metadata[layer_name] = attn_metadata_mtp
 
         for step in range(self.num_speculative_tokens):
@@ -324,26 +317,8 @@ class MtpProposer(EagleProposer):
                 positions = self.positions[:num_input_tokens]
                 hidden_states = self.hidden_states[:num_input_tokens]
 
-                if self.enable_shared_expert_dp:
-                    # positions [N] -> [N, 1] for padding
-                    positions = positions.unsqueeze(-1)
-                    positions = torch.ops.vllm.maybe_pad_and_reduce(
-                        positions)
-                    positions = positions.squeeze(-1)
-                    hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
-                        hidden_states)
-
-                for layer_name in self.attn_layer_name:
-                    decode_metadata = getattr(attn_metadata[layer_name],
-                                              "decode", None)
-                    if self.use_async_scheduling and decode_metadata is not None:
-                        actual_size = len(
-                            decode_metadata.actual_seq_lengths_q)
-
-                        decode_metadata.seq_lens_list = \
-                            decode_metadata.seq_lens_list[:actual_size]
-                        decode_metadata.block_table = \
-                            decode_metadata.block_table[:actual_size]
+                hidden_states, positions = self.maybe_pad_and_reduce(
+                    hidden_states, positions)
                 hidden_states = self.model(input_ids=input_ids,
                                            positions=positions,
                                            hidden_states=hidden_states)
@@ -353,11 +328,8 @@ class MtpProposer(EagleProposer):
                     self._update_full_graph_params(forward_context,
                                                    num_input_tokens)
 
-                if self.enable_shared_expert_dp:
-                    hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                        hidden_states.contiguous(), True)
-                    positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-                        positions.contiguous(), True)
+                hidden_states, positions, _ = self.maybe_all_gather_and_unpad(
+                    hidden_states, positions)
 
                 num_indices = last_token_indices.shape[0]
                 if lmhead_tp_enable():
@@ -398,7 +370,7 @@ class MtpProposer(EagleProposer):
                 if step == self.num_speculative_tokens - 1 or with_prefill:
                     break
 
-                attn_metadata_i = attn_metadata[self.attn_layer_names[0]]
+                attn_metadata_i = attn_metadata[self.attn_layer_names[0]]
 
                 if step == 0:
                     positions = target_positions[last_token_indices]
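For context on the refactor above: the per-call-site `enable_shared_expert_dp` / `sp_enabled` branches are folded into a single pre/post pair on the proposer, `maybe_pad_and_reduce` before the draft forward pass and `maybe_all_gather_and_unpad` after it, with behavior keyed on the speculative method. The standalone sketch below only mirrors that control flow with plain tensors; `DraftProposerSketch`, `fake_pad_and_reduce`, and `fake_all_gather_and_unpad` are hypothetical stand-ins for the real proposer classes and the `torch.ops.vllm` custom ops, which are not reproduced here.

```python
# Minimal, self-contained sketch of the pre/post helper pattern in this diff.
# fake_* are placeholders for torch.ops.vllm.maybe_pad_and_reduce /
# maybe_all_gather_and_maybe_unpad so the example runs without vllm-ascend.
import torch


def fake_pad_and_reduce(x: torch.Tensor) -> torch.Tensor:
    # Placeholder: the real op pads to the SP group size and reduce-scatters.
    return x


def fake_all_gather_and_unpad(x: torch.Tensor) -> torch.Tensor:
    # Placeholder: the real op all-gathers and strips the padding again.
    return x.contiguous()


class DraftProposerSketch:
    """Hypothetical stand-in for EagleProposer/MtpProposer."""

    def __init__(self, method: str, enable_shared_expert_dp: bool):
        self.method = method  # "mtp" or an eagle-style method
        self.enable_shared_expert_dp = enable_shared_expert_dp

    def maybe_pad_and_reduce(self, hidden_states, positions):
        # MTP with shared-expert DP pads/reduces both tensors; eagle-style
        # methods instead split hidden states under sequence parallelism
        # (omitted here for brevity).
        if self.method == "mtp" and self.enable_shared_expert_dp:
            hidden_states = fake_pad_and_reduce(hidden_states)
            positions = fake_pad_and_reduce(
                positions.unsqueeze(-1)).squeeze(-1)
        return hidden_states, positions

    def maybe_all_gather_and_unpad(self, last_hidden_states, positions,
                                   hidden_states=None):
        # Inverse of maybe_pad_and_reduce, applied after the draft forward.
        if self.method == "mtp" and self.enable_shared_expert_dp:
            last_hidden_states = fake_all_gather_and_unpad(last_hidden_states)
            positions = fake_all_gather_and_unpad(positions)
            if hidden_states is not None:
                hidden_states = last_hidden_states
        return last_hidden_states, positions, hidden_states


if __name__ == "__main__":
    proposer = DraftProposerSketch(method="mtp", enable_shared_expert_dp=True)
    hs, pos = torch.randn(4, 8), torch.arange(4)
    hs, pos = proposer.maybe_pad_and_reduce(hs, pos)
    # ... the draft model forward pass would run here ...
    hs, pos, _ = proposer.maybe_all_gather_and_unpad(hs, pos)
    print(hs.shape, pos.shape)
```

Centralizing the pad/gather logic this way keeps dummy_run and the propose loops free of communication branches, which is why the diff can delete the repeated `sp_enabled` / `enable_shared_expert_dp` blocks at every call site.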