diff --git a/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
index 58d4c709..5cca89d5 100644
--- a/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
+++ b/tests/e2e/singlecard/spec_decode/test_v1_spec_decode.py
@@ -34,6 +34,10 @@ BASELINES = {
     "eagle3": [0.68, 0.40, 0.18],
 }
 
+BASELINES_SP = {
+    "eagle3": [0.68, 0.40, 0.18],
+}
+
 
 @pytest.fixture
 def test_prompts():
@@ -371,3 +375,111 @@ def test_llama_qwen_eagle_acceptance(
         print(f"golden: {golden}")
 
     assert match
+
+
+# TODO: SP support in eagle3 is still being improved; SP + DP and some
+# not-yet-diagnosed scenarios still have problems, so this e2e test
+# should be extended as support matures.
+@pytest.mark.parametrize("method", ["eagle3"])
+@pytest.mark.parametrize("num_speculative_tokens", [3])
+@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
+@pytest.mark.parametrize("async_scheduling", [True, False])
+def test_eagle3_sp_acceptance(
+    method: str,
+    num_speculative_tokens: int,
+    disable_padded_drafter_batch: bool,
+    async_scheduling: bool,
+):
+    if disable_padded_drafter_batch and async_scheduling:
+        pytest.skip(
+            "unsupported combination: disable_padded_drafter_batch=True "
+            "with async_scheduling=True",
+        )
+
+    main_model_name = MODELS[method]["main"]
+    spec_model_name = MODELS[method]["spec"]
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        main_model_name,
+        trust_remote_code=True,
+    )
+    sampling_params = SamplingParams(
+        temperature=0,
+        ignore_eos=False,
+        max_tokens=256,
+    )
+
+    # SP is only enabled when query_lens > 1000
+    prompts = [
+        {
+            "role": "user",
+            "content": " " * 1000 + "Hello, my name is",
+        },
+        {
+            "role": "user",
+            "content": " " * 1000 + "The president of the United States is",
+        },
+        {
+            "role": "user",
+            "content": " " * 1000 + "The capital of France is",
+        },
+        {
+            "role": "user",
+            "content": " " * 1000 + "The future of AI is",
+        },
+    ]
+    prompts = [
+        tokenizer.apply_chat_template(
+            [prompt],
+            tokenize=False,
+            add_generation_prompt=True,
+        ) for prompt in prompts
+    ]
+
+    speculative_config = {
+        "method": method,
+        "num_speculative_tokens": num_speculative_tokens,
+        "disable_padded_drafter_batch": disable_padded_drafter_batch,
+        "model": spec_model_name,
+    }
+
+    compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
+
+    with VllmRunner(
+            main_model_name,
+            enforce_eager=True,
+            max_model_len=8192,
+            disable_log_stats=False,
+            tensor_parallel_size=1,
+            max_num_seqs=256,
+            distributed_executor_backend="mp",
+            gpu_memory_utilization=0.7,
+            speculative_config=speculative_config,
+            compilation_config=compilation_config,
+            async_scheduling=async_scheduling,
+    ) as llm:
+        _ = llm.generate(prompts, sampling_params)
+        metrics = llm.model.get_metrics()
+
+    num_drafts = 0
+    num_accepted_tokens_per_pos = [0] * num_speculative_tokens
+    for metric in metrics:
+        if metric.name == "vllm:spec_decode_num_drafts":
+            assert isinstance(metric, Counter)
+            num_drafts += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
+            assert isinstance(metric, Vector)
+            for pos in range(len(metric.values)):
+                num_accepted_tokens_per_pos[pos] += metric.values[pos]
+
+    acceptance_per_pos = [
+        num_accepted_tokens / num_drafts
+        for num_accepted_tokens in num_accepted_tokens_per_pos
+    ]
+    golden = BASELINES_SP[method]
+
+    match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
+    if not match:
+        print(f"acceptance_per_pos: {acceptance_per_pos}")
+        print(f"golden: {golden}")
+
+    assert match
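For reference, the acceptance check the new test performs reduces to a per-position ratio against a golden baseline with a 0.06 tolerance. A minimal standalone sketch follows; the counter values are made-up stand-ins for the `vllm:spec_decode_num_drafts` and `vllm:spec_decode_num_accepted_tokens_per_pos` metrics the test reads.

```python
# Sketch of the acceptance check in test_eagle3_sp_acceptance; the
# counter values below are hypothetical, chosen only for illustration.
num_drafts = 200
num_accepted_tokens_per_pos = [136, 80, 36]

# Per-position acceptance: fraction of drafts whose token at that
# position was accepted by the target model.
acceptance_per_pos = [n / num_drafts for n in num_accepted_tokens_per_pos]
golden = [0.68, 0.40, 0.18]  # BASELINES_SP["eagle3"]

# Pass if every position is within 0.06 of its baseline.
assert all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
```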
diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py
index 5c037ed4..3e30ecc8 100644
--- a/tests/ut/spec_decode/test_eagle_proposer.py
+++ b/tests/ut/spec_decode/test_eagle_proposer.py
@@ -275,6 +275,8 @@ class TestEagleProposerDummyRun(TestBase):
         num_tokens = 32
         with_prefill = False
 
+        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
+        self.proposer.enable_shared_expert_dp = False
         self.proposer.dummy_run(num_tokens=num_tokens,
                                 with_prefill=with_prefill)
 
@@ -284,6 +286,8 @@
     @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
     def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
         mock_context.return_value.__enter__.return_value = None
+        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
+        self.proposer.enable_shared_expert_dp = False
         self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
         self.assertTrue(self.proposer.model.call_count == 4)
 
@@ -298,6 +302,8 @@
         mock_return_context.capturing = True
         mock_get_context.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
+        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
+        self.proposer.enable_shared_expert_dp = False
         self.proposer.dummy_run(num_tokens=64,
                                 in_graph_capturing=True,
                                 aclgraph_runtime_mode=CUDAGraphMode.FULL)
@@ -316,6 +322,8 @@
         mock_return_context.capturing = False
         mock_get_context.return_value = mock_return_context
         self.proposer.use_cuda_graph = True
+        # cpu does not support `torch.ops.vllm.maybe_pad_and_reduce`
+        self.proposer.enable_shared_expert_dp = False
         self.proposer.dummy_run(num_tokens=64,
                                 in_graph_capturing=False,
                                 aclgraph_runtime_mode=CUDAGraphMode.FULL)
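The unit-test changes above all follow one pattern: `torch.ops.vllm.maybe_pad_and_reduce` is only available on the NPU path, so the CPU-based tests force `enable_shared_expert_dp = False` to keep `dummy_run` on the plain path. A hedged sketch of the guard shape the proposer relies on; the helper name `pad_and_reduce_if_enabled` is invented for illustration and is not part of the codebase.

```python
import torch


def pad_and_reduce_if_enabled(hidden_states: torch.Tensor,
                              enable_shared_expert_dp: bool) -> torch.Tensor:
    # With shared-expert DP enabled on NPU, the custom op splits the
    # sequence dimension across ranks; with the flag off (as in the CPU
    # unit tests), this is a no-op and the tensor passes through.
    if enable_shared_expert_dp:
        return torch.ops.vllm.maybe_pad_and_reduce(hidden_states)
    return hidden_states


x = torch.randn(8, 16)
assert pad_and_reduce_if_enabled(x, enable_shared_expert_dp=False) is x
```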
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 1684ab56..694f40c0 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -225,7 +225,6 @@ class EagleProposer(VllmEagleProposer):
             decode_token_per_req=self.runner.decode_token_per_req,
             max_seq_len=0,
         )
-        dummy_compute_logits(self.hidden_states)
 
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata_eagle = builder.build_for_graph_capture(
@@ -233,6 +232,10 @@
         attn_metadata = {}
         for layer_name in self.attn_layer_name:
             attn_metadata[layer_name] = attn_metadata_eagle
+
+        model_input_ids = self.input_ids[:num_tokens]
+        model_positions = self.positions[:num_tokens]
+        model_previous_hidden_states = self.hidden_states[:num_tokens]
         for i in range(self.num_speculative_tokens):
             if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
                 aclgraph_runtime_mode = CUDAGraphMode.NONE
@@ -245,12 +248,17 @@
                     batch_descriptor=batch_descriptor,
                     aclgraph_runtime_mode=aclgraph_runtime_mode,
                     is_draft_model=True):
-                forward_context = get_forward_context()
+
+                if self.enable_shared_expert_dp:
+                    model_previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+                        model_previous_hidden_states)
+
                 self.model(
-                    input_ids=self.input_ids[:num_tokens],
-                    positions=self.positions[:num_tokens],
-                    hidden_states=self.hidden_states[:num_tokens],
+                    input_ids=model_input_ids,
+                    positions=model_positions,
+                    hidden_states=model_previous_hidden_states,
                 )
+                forward_context = get_forward_context()
                 if (forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and
                         not forward_context.capturing):
@@ -261,6 +269,12 @@ class EagleProposer(VllmEagleProposer):
                     self.vllm_config,
                 )
 
+            if self.enable_shared_expert_dp:
+                model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    model_previous_hidden_states, True)
+
+        dummy_compute_logits(self.hidden_states)
+
     def _propose(
         self,
         # [num_tokens]
@@ -338,10 +352,24 @@
                 batch_descriptor=batch_descriptor,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 is_draft_model=True):
+
+            # `input_ids`, `positions` and `hidden_states` live across all speculative-token proposal steps;
+            # the `model_*` views below hold the per-step inputs of the speculative model.
+            model_input_ids = self.input_ids[:num_input_tokens]
+            model_positions = self.positions[:num_input_tokens]
+            model_hidden_states = self.hidden_states[:num_input_tokens]
+
+            if self.enable_shared_expert_dp:
+                # split hidden states along the sequence dimension
+                # positions should not be split?
+                model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+                    model_hidden_states)
+                # in acl-graph, should `model_hidden_states` be copied back to `self.hidden_states`?
+
             last_hidden_states, hidden_states = self.model(
-                input_ids=self.input_ids[:num_input_tokens],
-                positions=self.positions[:num_input_tokens],
-                hidden_states=self.hidden_states[:num_input_tokens],
+                input_ids=model_input_ids,
+                positions=model_positions,
+                hidden_states=model_hidden_states,
             )
             forward_context = get_forward_context()
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
@@ -352,6 +380,14 @@
                     num_input_tokens,
                     self.vllm_config,
                 )
+
+            if self.enable_shared_expert_dp:
+                # merge hidden states along the sequence dimension
+                last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    last_hidden_states.contiguous(), True)
+                hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                    hidden_states.contiguous(), True)
+
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         draft_token_ids = logits.argmax(dim=-1)
@@ -470,10 +506,23 @@
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 is_draft_model=True):
 
+                # `input_ids`, `positions` and `hidden_states` live across all speculative-token proposal steps;
+                # the `model_*` views below hold the per-step inputs of the speculative model.
+                model_input_ids = self.input_ids[:input_batch_size]
+                model_positions = self.positions[:input_batch_size]
+                model_hidden_states = self.hidden_states[:input_batch_size]
+
+                if self.enable_shared_expert_dp:
+                    # split hidden states along the sequence dimension
+                    # positions should not be split?
+                    model_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+                        model_hidden_states)
+                    # in acl-graph, should `model_hidden_states` be copied back to `self.hidden_states`?
+
                 last_hidden_states, hidden_states = self.model(
-                    input_ids=self.input_ids[:input_batch_size],
-                    positions=self.positions[:input_batch_size],
-                    hidden_states=self.hidden_states[:input_batch_size],
+                    input_ids=model_input_ids,
+                    positions=model_positions,
+                    hidden_states=model_hidden_states,
                 )
                 forward_context = get_forward_context()
                 if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
@@ -483,6 +532,14 @@
                     input_batch_size,
                     self.vllm_config,
                 )
+
+        if self.enable_shared_expert_dp:
+            # merge hidden states along the sequence dimension
+            last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                last_hidden_states.contiguous(), True)
+            hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+                hidden_states.contiguous(), True)
+
         hidden_states = hidden_states[:batch_size]
         logits = self.model.compute_logits(last_hidden_states[:batch_size])
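The eagle_proposer changes wrap the drafter's inputs and outputs in a split/merge pair: `maybe_pad_and_reduce` pads the sequence dimension to a multiple of the TP world size and leaves each rank one shard, and `maybe_all_gather_and_maybe_unpad` reverses it. Below is a single-process sketch of that round trip; plain slicing stands in for the collective ops, so this is shape bookkeeping only, not the actual reduce-scatter.

```python
import torch
import torch.nn.functional as F


def pad_and_shard(h: torch.Tensor, rank: int, world_size: int):
    # Pad the sequence dim to a multiple of world_size, then keep this
    # rank's contiguous shard (the shape a reduce-scatter leaves behind).
    pad_size = (-h.shape[0]) % world_size
    if pad_size > 0:
        h = F.pad(h, (0, 0, 0, pad_size))
    chunk = h.shape[0] // world_size
    return h[rank * chunk:(rank + 1) * chunk], pad_size


def gather_and_unpad(shards, pad_size: int) -> torch.Tensor:
    # All-gather along the sequence dim, then strip the padding rows.
    h = torch.cat(list(shards), dim=0)
    return h[:-pad_size] if pad_size > 0 else h


hidden = torch.randn(10, 4)   # 10 tokens, hidden size 4
world_size = 4                # pads 10 -> 12, shards of 3 rows each
shards, pads = zip(*(pad_and_shard(hidden, r, world_size)
                     for r in range(world_size)))
assert torch.equal(gather_and_unpad(shards, pads[0]), hidden)
```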
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 5b16e433..a1788d8d 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1056,6 +1056,33 @@ class NPUModelRunner(GPUModelRunner):
             input_ids, inputs_embeds, intermediate_tensors,
             max_num_scheduled_tokens)
 
+    # all-gather a single hidden-states tensor in the SP case
+    @staticmethod
+    def _all_gather_hidden_states(hidden_states):
+        hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
+        pad_size = get_forward_context().pad_size
+        if pad_size > 0:
+            hidden_states = hidden_states[:-pad_size, :]
+
+        return hidden_states
+
+    # all-gather a list of hidden-states tensors in the SP case
+    @staticmethod
+    def _all_gather_hidden_states_list(hidden_states_list):
+        return [
+            NPUModelRunner._all_gather_hidden_states(hidden_states)
+            for hidden_states in hidden_states_list
+        ]
+
+    # all-gather the last-layer hidden states together with any aux hidden states in the SP case
+    @staticmethod
+    def _all_gather_hidden_states_and_aux(hidden_states):
+        if isinstance(hidden_states, tuple):
+            return (NPUModelRunner._all_gather_hidden_states(hidden_states[0]),
+                    NPUModelRunner._all_gather_hidden_states_list(
+                        hidden_states[1]))
+        return NPUModelRunner._all_gather_hidden_states(hidden_states)
+
     def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens,
                                              input_ids, positions,
                                              intermediate_tensors,
@@ -1103,10 +1130,8 @@
         if get_forward_context().sp_enabled and not isinstance(
                 hidden_states, IntermediateTensors):
-            hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
-            pad_size = get_forward_context().pad_size
-            if pad_size > 0:
-                hidden_states = hidden_states[:-pad_size, :]
+            hidden_states = self._all_gather_hidden_states_and_aux(
+                hidden_states)
 
         return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states(
             hidden_states)
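The model_runner refactor exists because an eagle3 target model can return `(last_hidden_states, aux_hidden_states_list)` rather than a bare tensor, and the previous inline all-gather only handled the tensor case. A toy sketch of the dispatch shape follows; the identity `gather_one` stands in for the real `tensor_model_parallel_all_gather` plus unpad, so only the tuple handling is illustrated.

```python
import torch


def gather_one(h: torch.Tensor) -> torch.Tensor:
    # Stand-in for _all_gather_hidden_states: the real helper all-gathers
    # across TP ranks along dim 0 and strips pad rows; identity here.
    return h


def gather_and_aux(hidden_states):
    # Mirrors _all_gather_hidden_states_and_aux: a bare tensor is gathered
    # directly; a (last_hidden_states, aux_list) tuple gathers each part.
    if isinstance(hidden_states, tuple):
        last, aux = hidden_states
        return gather_one(last), [gather_one(a) for a in aux]
    return gather_one(hidden_states)


t = torch.randn(4, 8)
assert gather_and_aux(t) is t
last, aux = gather_and_aux((t, [t, t]))
assert last is t and len(aux) == 2
```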