diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py index f207c64d..150fbeec 100644 --- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py +++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py @@ -7,9 +7,10 @@ import random from typing import Any import pytest +from transformers import AutoTokenizer from vllm import LLM, SamplingParams -from tests.e2e.conftest import VllmRunner +from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -115,41 +116,67 @@ def test_eagle_correctness( Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. ''' - pytest.skip("To be aligned with GPU") - ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=False) - ref_outputs = ref_llm.chat(test_prompts, sampling_config) - del ref_llm - + # NOTE: The eagle e2e test had many problems before. + # For now we only check that it runs properly end to end. + # The test should be reworked to use VllmRunner in the future. spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name() - with VllmRunner( - model_name, - max_num_seqs=1, - max_num_batched_tokens=2048, - gpu_memory_utilization=0.6, - speculative_config={ - "method": "eagle3" if use_eagle3 else "eagle", - "model": spec_model_name, - "num_speculative_tokens": 2, - "max_model_len": 128, - }, - max_model_len=128, - enforce_eager=False, - ) as runner: - spec_outputs = runner.model.chat(test_prompts, sampling_config) + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompts = [{ + "role": "user", + "content": "Hello, my name is" + }, { + "role": "user", + "content": "The president of the United States is" + }, { + "role": "user", + "content": "The capital of France is" + }, { + "role": "user", + "content": "The future of AI is" + }] + prompts = [ + tokenizer.apply_chat_template( + [prompt], + tokenize=False, + add_generation_prompt=True, + ) for prompt in prompts + ] - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") + sampling_params = SamplingParams( + max_tokens=300, + temperature=0.0, + ignore_eos=False, + ) - # Heuristic: expect at least 66% of the prompts to match exactly - # Upon failure, inspect the outputs to check for inaccuracy. - assert matches > int(0.66 * len(ref_outputs)) + # Create an LLM.
+ llm = LLM( + model=model_name, + tensor_parallel_size=1, + pipeline_parallel_size=1, + data_parallel_size=1, + disable_log_stats=False, + max_model_len=4096, + seed=1024, + async_scheduling=True, + compilation_config={ + "level": 3, + "cudagraph_mode": "FULL_DECODE_ONLY", + "cudagraph_num_of_warmups": 1, + "cudagraph_capture_sizes": [12], + }, + speculative_config={ + "disable_padded_drafter_batch": False, + "method": "eagle3" if use_eagle3 else "eagle", + "model": spec_model_name, + "num_speculative_tokens": 2, + "max_model_len": 128, + "draft_vocab_size": 128256, + }, + ) + llm.generate(prompts, sampling_params) + cleanup_dist_env_and_memory() + del llm @pytest.mark.skip( diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index cbfda43a..f688aa95 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -26,6 +26,13 @@ class TestEagleProposerInitialization(TestBase): self.vllm_config.model_config.dtype = torch.float16 self.vllm_config.model_config.max_model_len = 2048 + self.mock_cpugpubuffer = patch( + "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer") + self.mock_cpugpubuffer.start() + + def tearDown(self): + self.mock_cpugpubuffer.stop() + def test_initialization_eagle(self): self.vllm_config.speculative_config.method = "eagle" self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096 @@ -44,7 +51,7 @@ class TestEagleProposerInitialization(TestBase): self.assertEqual(proposer.input_ids.shape, (1024, )) self.assertEqual(proposer.positions.shape, (1024, )) self.assertEqual(proposer.hidden_states.shape, (1024, 4096)) - self.assertEqual(proposer.arange.shape, (33, )) + self.assertEqual(proposer.arange.shape, (1024, )) def test_initialization_eagle3(self): self.vllm_config.speculative_config.method = "eagle3" @@ -77,10 +84,16 @@ class TestEagleProposerLoadModel(TestBase): self.vllm_config.model_config.dtype = torch.float16 self.vllm_config.model_config.max_model_len = 2048 + self.mock_cpugpubuffer = patch( + "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer") + self.mock_cpugpubuffer.start() self.proposer = EagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) + def tearDown(self): + self.mock_cpugpubuffer.stop() + @patch( "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") @patch("vllm_ascend.spec_decode.eagle_proposer.get_model") @@ -172,11 +185,17 @@ class TestEagleProposerDummyRun(TestBase): self.vllm_config.model_config.dtype = torch.float16 self.vllm_config.model_config.max_model_len = 2048 + self.mock_cpugpubuffer = patch( + "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer") + self.mock_cpugpubuffer.start() self.proposer = EagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) self.proposer.model = MagicMock() + def tearDown(self): + self.mock_cpugpubuffer.stop() + @patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context") def test_dummy_run_basic(self, mock_context): num_tokens = 32 @@ -216,6 +235,9 @@ class TestEagleProposerGenerateTokenIds(TestBase): self.vllm_config.model_config.dtype = torch.float16 self.vllm_config.model_config.max_model_len = 2048 + self.mock_cpugpubuffer = patch( + "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer") + self.mock_cpugpubuffer.start() self.proposer = EagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) @@ -223,7 +245,12 @@ class 
TestEagleProposerGenerateTokenIds(TestBase): self.proposer._propose = MagicMock( return_value=torch.tensor([[1, 2], [3, 4], [5, 6]])) - def test_generate_token_ids_without_metadata(self): + def tearDown(self): + self.mock_cpugpubuffer.stop() + + # TODO: This is equivalent to disable_padded_drafter_batch=True. + # We need to add some cases about disable_padded_drafter_batch=False in future. + def test_generate_token_ids(self): valid_sampled = [[20, 30, 40]] scheduler_output = MagicMock() scheduler_output.num_scheduled_tokens = [2, 1, 3] @@ -239,7 +266,7 @@ class TestEagleProposerGenerateTokenIds(TestBase): return_value={"layer_0": mock_attn_metadata}) result = self.proposer.generate_token_ids( - valid_sampled_token_ids=valid_sampled, + sampled_token_ids=valid_sampled, scheduler_output=scheduler_output, positions=positions, num_scheduled_tokens=num_scheduled, @@ -247,36 +274,13 @@ class TestEagleProposerGenerateTokenIds(TestBase): ) self.proposer._propose.assert_called_once() - self.assertEqual(result, [[1, 2], [3, 4], [5, 6]]) - - def test_generate_token_ids_with_metadata(self): - valid_sampled = [[5], [6, 7], [8, 9, 10]] - spec_metadata = MagicMock() - spec_metadata.num_draft_tokens = [2, 3, 4] - - mock_attn_metadata = MagicMock() - mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5]) - mock_attn_metadata.query_start_loc = torch.tensor([0, 1, 3, 6]) - mock_attn_metadata.block_tables = MagicMock() - self.proposer._get_eagle_atten_dict = MagicMock( - return_value={"layer_0": mock_attn_metadata}) - self.proposer._prepare_inputs = MagicMock( - return_value=(torch.tensor([0, 2, 5]), torch.tensor([1, 3, 5]))) - - result = self.proposer.generate_token_ids( - valid_sampled_token_ids=valid_sampled, - spec_decode_metadata=spec_metadata, - positions=torch.randn(6, 1), - hidden_states=torch.randn(6, 4096), - ) - - self.proposer._prepare_inputs.assert_called_once() - self.assertEqual(self.proposer._propose.call_count, 1) - self.assertEqual(len(result), 3) + self.assertEqual(result.numpy().tolist(), [[1, 2], [3, 4], [5, 6]]) class TestEagleProposerHelperMethods(TestBase): + # TODO: Can add some tests about prepare_next_token_ids in future. + def setUp(self): self.vllm_config = MagicMock(spec=VllmConfig) self.vllm_config.scheduler_config = MagicMock(max_num_seqs=3) @@ -293,21 +297,29 @@ class TestEagleProposerHelperMethods(TestBase): self.vllm_config.model_config.dtype = torch.float16 self.vllm_config.model_config.max_model_len = 2048 + self.mock_cpugpubuffer = patch( + "vllm_ascend.spec_decode.eagle_proposer.CpuGpuBuffer") + self.mock_cpugpubuffer.start() self.proposer = EagleProposer(vllm_config=self.vllm_config, device=self.device, runner=self.runner) + def tearDown(self): + self.mock_cpugpubuffer.stop() + + # TODO: This is equivalent to disable_padded_drafter_batch=True. + # We need to add a test_prepare_inputs_padded in future. 
def test_prepare_inputs(self): self.proposer.token_arange_np = np.arange(10) mock_attn = MagicMock() mock_attn.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5]) num_rejected = torch.tensor([1, 0, 1], device=self.device) + mock_return_attn = MagicMock() with patch.object(self.proposer, - '_prepare_inputs', - return_value=(torch.tensor([0, 2, 5]), + 'prepare_inputs', + return_value=(mock_return_attn, torch.tensor([1, 2, 4]))): - cu_num_tokens, indices = self.proposer._prepare_inputs( + return_attn, indices = self.proposer.prepare_inputs( mock_attn, num_rejected) - self.assertEqual(cu_num_tokens.tolist(), [0, 2, 5]) self.assertEqual(indices.tolist(), [1, 2, 4]) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index ad22b0dc..7faa30a9 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -730,6 +730,9 @@ class AscendAttentionBackendImpl(AttentionImpl): self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping if get_ascend_device_type() == AscendDeviceType._910_95: + # TODO: Once eagle runs through here, it may raise an error because of the 0th dim of slot_mapping. + # Should check whether the 0th dim of slot_mapping must equal the 0th dim of key. + # If so, the slots should be sliced. torch_npu.npu_scatter_pa_kv_cache( key=key[:attn_metadata.num_actual_tokens], value=value[:attn_metadata.num_actual_tokens].contiguous(), @@ -742,7 +745,7 @@ class AscendAttentionBackendImpl(AttentionImpl): value=value[:attn_metadata.num_actual_tokens], key_cache=self.key_cache, value_cache=self.value_cache, - slot_indices=slots) + slot_indices=slots[:attn_metadata.num_actual_tokens]) return key, value def forward_impl( diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index cb95871a..454190a5 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -119,6 +119,35 @@ class AscendCommonAttentionMetadata: prefill_context_parallel_metadata: Optional[ AscendPrefillContextParallelMetadata] = None + # TODO: Remove it when vLLM no longer uses this function. + def unpadded(self, num_actual_tokens: int, + num_actual_reqs: int) -> "AscendCommonAttentionMetadata": + # This is only used by eagle now. It will also be used by enforce_eager in the future. + return AscendCommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_actual_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_actual_reqs + 1], + seq_lens=self.seq_lens[:num_actual_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_actual_reqs], + num_computed_tokens_cpu=self. + num_computed_tokens_cpu[:num_actual_reqs], + num_reqs=num_actual_reqs, + num_actual_tokens=num_actual_tokens, + max_query_len=self.max_query_len, + decode_token_per_req=self.decode_token_per_req, + block_table_tensor=self.block_table_tensor[:num_actual_reqs], + slot_mapping=self.slot_mapping[:num_actual_tokens], + actual_seq_lengths_q=self.actual_seq_lengths_q[:num_actual_tokens], + positions=self.positions[:num_actual_tokens], + attn_mask=self.attn_mask, + spec_attn_mask=self.spec_attn_mask, + attn_state=self.attn_state, + is_only_prefill=self.is_only_prefill, + graph_pad_size=-1, # It should be -1 when not running in fullgraph mode. + num_input_tokens=num_actual_tokens, + prefill_context_parallel_metadata=self.
+ prefill_context_parallel_metadata, + ) + def filter_chunked_req_indices( seq_len: torch.Tensor, diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index a71c3537..24a846d9 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -14,19 +14,25 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.utils.platform_utils import is_pin_memory_available +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder -from vllm_ascend.attention.attention_v1 import (AscendAttentionState, - AscendMetadata) +from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType PADDING_SLOT_ID = -1 +_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn' + +_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'} + class EagleProposer(Proposer): @@ -54,6 +60,19 @@ class EagleProposer(Proposer): sorted( self.vllm_config.compilation_config.cudagraph_capture_sizes)) + max_batch_size = vllm_config.scheduler_config.max_num_seqs + # Currently we do not use pcp. This is used to adapt the pcp branch. + self.pcp_size = 0 + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True, + ) + self.decode_threshold = 1 + \ + self.vllm_config.speculative_config.num_speculative_tokens + # persistent buffers for cuda graph self.input_ids = torch.zeros( self.vllm_config.scheduler_config.max_num_batched_tokens, @@ -71,12 +90,13 @@ class EagleProposer(Proposer): self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) self.token_arange_np = np.arange(self.max_num_tokens) - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. 
- self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + - 1, + max_num_slots_for_arange = max(self.max_num_tokens, max_batch_size + 1) + self.arange = torch.arange(max_num_slots_for_arange, device=device, dtype=torch.int32) + self.arange_cpu = torch.arange(max_num_slots_for_arange, + device="cpu", + dtype=torch.int32) self.attn_mask_builder = AttentionMaskBuilder(self.device) def load_model(self, model: nn.Module) -> None: @@ -135,8 +155,7 @@ class EagleProposer(Proposer): dummy_compute_logits(self.hidden_states) def generate_token_ids(self, - valid_sampled_token_ids: torch.Tensor - | list[list[int]], + sampled_token_ids: torch.Tensor | list[list[int]], sampling_metadata: SamplingMetadata = None, scheduler_output: SchedulerOutput = None, spec_decode_metadata: SpecDecodeMetadata = None, @@ -144,273 +163,155 @@ class EagleProposer(Proposer): num_scheduled_tokens: int = 0, hidden_states: torch.Tensor = None, aux_hidden_states: torch.Tensor = None): + common_attn_metadata = self.runner.spec_decode_common_attn_metadata - attn_metadata = self._get_eagle_atten_dict(scheduler_output) - next_token_ids: list[int] = [] - for i, token_ids in enumerate(valid_sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.runner.input_batch.req_ids[i] - req_state = self.runner.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - eagle_attn_metadata = attn_metadata[self.attn_layer_name] - if spec_decode_metadata is None: - # input_ids can be None for multimodal models. - target_token_ids = self.runner.input_ids.gpu[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - if self.name == SpecDcodeType.EAGLE3: - target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) - else: - target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = eagle_attn_metadata.slot_mapping - cu_num_tokens = eagle_attn_metadata.query_start_loc + if self.vllm_config.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), \ + "sampled_token_ids should be a python list when" \ + "padded-batch is disabled." 
+ next_token_ids = self.prepare_next_token_ids_cpu( + sampled_token_ids, self.runner.requests, + self.runner.input_batch, scheduler_output.num_scheduled_tokens) else: - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor( - num_rejected_tokens, - dtype=torch.int32, - device=self.device, - ) - cu_num_tokens, token_indices =\ - self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens) - target_token_ids = self.runner.input_ids.gpu[token_indices] - target_positions = positions[token_indices] - if self.name == SpecDcodeType.EAGLE3: - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), \ + "sampled_token_ids should be a torch.Tensor when" \ + "padded-batch is enabled." + next_token_ids, valid_sampled_tokens_count = \ + self.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.runner.requests, + self.runner.input_batch, + self.runner.discard_request_indices.gpu, + self.runner.num_discarded_requests + ) + self._copy_valid_sampled_token_count(next_token_ids, + valid_sampled_tokens_count) + + req_scheduled_tokens = scheduler_output.num_scheduled_tokens + if self.pcp_size > 1: + long_seq_metadata = self.runner.long_seq_metadata + input_ids_pcp_full = self.runner.input_ids_pcp_full + query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full + query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu + num_reqs = self.runner.input_batch.num_reqs + ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ + query_start_loc_pcp_full_cpu[:num_reqs] + num_prefill_reqs = (ori_query_lens + > self.decode_threshold).sum().item() + num_decode_reqs = num_reqs - num_prefill_reqs + else: + long_seq_metadata = None + num_prefill_reqs = 0 + num_decode_reqs = 0 + if spec_decode_metadata is None: + # update pcp related params + if self.pcp_size > 1: + token_indices_to_sample = \ + query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1 + target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + target_hidden_states = hidden_states else: - target_hidden_states = hidden_states[token_indices] - target_slot_mapping = eagle_attn_metadata.slot_mapping[ - token_indices] + token_indices_to_sample = None + # input_ids can be None for multimodal models. 
+ target_token_ids = self.runner.input_ids.gpu[: + num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + if self.name == SpecDcodeType.EAGLE3: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[:num_scheduled_tokens] + else: + if self.pcp_size > 1: + common_attn_metadata.query_start_loc_cpu = \ + query_start_loc_pcp_full_cpu[:num_reqs + 1] + common_attn_metadata.query_start_loc = \ + query_start_loc_pcp_full[:num_reqs + 1] + if self.vllm_config.speculative_config.disable_padded_drafter_batch: + # NOTE: Currently, MTP-fullgraph is incompatibility with pcp + token_indices_to_sample = None + common_attn_metadata, token_indices =\ + self.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens) + else: + common_attn_metadata, token_indices, \ + token_indices_to_sample =\ + self.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) + if self.pcp_size > 1: + target_token_ids = input_ids_pcp_full[token_indices] + target_positions = positions + target_hidden_states = hidden_states + else: + target_token_ids = self.runner.input_ids.gpu[token_indices] + target_positions = positions[token_indices] + if self.name == SpecDcodeType.EAGLE3: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], dim=-1) + else: + target_hidden_states = hidden_states[token_indices] draft_token_ids = self._propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=eagle_attn_metadata.block_tables, + last_token_indices=token_indices_to_sample, + common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, + req_scheduled_tokens=req_scheduled_tokens, + long_seq_metadata=long_seq_metadata, + num_prefill_reqs=num_prefill_reqs, + num_decode_reqs=num_decode_reqs, + scheduler_output=scheduler_output, + num_scheduled_tokens=num_scheduled_tokens, ) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids - def _get_eagle_atten_dict( - self, - scheduler_output: "SchedulerOutput", - ): - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - assert total_num_scheduled_tokens > 0 - num_reqs = self.runner.input_batch.num_reqs - assert num_reqs > 0 - - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - self.runner.input_batch.block_table.commit_block_table(num_reqs) - - # Get the number of scheduled tokens for each request. - req_ids = self.runner.input_batch.req_ids - tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = max(tokens) - self.runner.query_lens = torch.from_numpy(num_scheduled_tokens) - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.runner.arange_np[:num_reqs], - num_scheduled_tokens) - - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) - - # Get positions. 
- positions_np = self.runner.positions.np[:total_num_scheduled_tokens] - np.add(self.runner.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) - - # Calculate M-RoPE positions. - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.runner.uses_mrope: - self.runner._calc_mrope_positions(scheduler_output) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. - token_indices = ( - positions_np + - req_indices * self.runner.input_batch.token_ids_cpu.shape[1]) - - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - torch.index_select( - self.runner.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.runner.input_ids.cpu[:total_num_scheduled_tokens]) - - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - # NOTE(Chen): there is exactly one KV cache group that contains all - # attetnion layers in the model for now, so the current logic for - # getting attn_metadata is not related to kv_cache_group information. - # Will extend this part to support multiple KV cache groups later. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.runner.kv_cache_config.kv_cache_groups): - block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_table = self.runner.input_batch.block_table[ - kv_cache_group_id] - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + - positions_np // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() - block_offsets = positions_np % block_size - np.add( - block_numbers * block_size, - block_offsets, - out=block_table.slot_mapping.np[:total_num_scheduled_tokens]) - - # Prepare the attention metadata. - self.runner.query_start_loc.np[0] = 0 - self.runner.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens - - self.runner.seq_lens.np[:num_reqs] = ( - self.runner.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - - # Copy the tensors to the NPU. - self.runner.input_ids.gpu[:total_num_scheduled_tokens].copy_( - self.runner.input_ids.cpu[:total_num_scheduled_tokens], - non_blocking=True) - if self.runner.uses_mrope: - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.runner.mrope_positions.gpu[:, :total_num_scheduled_tokens] \ - .copy_( - self.runner. - mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True) - else: - # Common case (1D positions) - self.runner.positions.gpu[:total_num_scheduled_tokens].copy_( - self.runner.positions.cpu[:total_num_scheduled_tokens], - non_blocking=True) - - self.runner.query_start_loc.gpu[:num_reqs + 1].copy_( - self.runner.query_start_loc.cpu[:num_reqs + 1], non_blocking=True) - self.runner.seq_lens.gpu[:num_reqs].copy_( - self.runner.seq_lens.cpu[:num_reqs], non_blocking=True) - - # Fill unused with -1. 
Needed for reshape_and_cache - self.runner.seq_lens.gpu[num_reqs:].fill_(0) - self.runner.query_start_loc.gpu[num_reqs + 1:].fill_(-1) - - attn_metadata = {} - # Prepare the attention metadata for each KV cache group and make layers - # in the same group share the same metadata. - for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.runner.kv_cache_config.kv_cache_groups): - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=self.runner.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.runner.query_start_loc.cpu[:num_reqs + - 1], - seq_lens_cpu=self.runner.seq_lens.cpu, - num_reqs=num_reqs, - max_query_len=max_num_scheduled_tokens, - num_actual_tokens=total_num_scheduled_tokens, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. - get_device_tensor(), - slot_mapping=self.runner.input_batch.block_table[0]. - slot_mapping.gpu, - positions=self.runner.positions.gpu, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - num_computed_tokens_cpu=None, - seq_lens=None) - builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata_i = builder.build(0, common_attn_metadata, - self.runner.get_model()) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i - - return attn_metadata - - def _get_cumsum_and_arange( - self, - num_tokens: np.ndarray, - cumsum_dtype: Optional[np.dtype] = None, - ) -> tuple[np.ndarray, np.ndarray]: - """Get the cumulative sum and batched arange of the given array. - # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) - # Equivalent to but faster than: - # np.concatenate([np.arange(n) for n in num_tokens]) - """ - # Step 1. [2, 5, 3] -> [2, 7, 10] - cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) - total_num_tokens = cu_num_tokens[-1] - # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] - cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) - # Step 3. 
[0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange = self.runner.arange_np[:total_num_tokens] - cumsums_offsets - - return cu_num_tokens, arange + return draft_token_ids def _propose( self, # [num_tokens] target_token_ids: torch.Tensor, - # [num_tokens] + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled target_positions: torch.Tensor, # [num_tokens, hidden_size] target_hidden_states: torch.Tensor, - # [num_tokens] - target_slot_mapping: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, + last_token_indices: Optional[torch.Tensor], + common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], + torch.Tensor]] = None, + req_scheduled_tokens=None, + long_seq_metadata=None, + num_prefill_reqs=0, + num_decode_reqs=0, + scheduler_output: SchedulerOutput = None, + num_scheduled_tokens: int = 0, ) -> torch.Tensor: - device = cu_num_tokens.device - cu_num_tokens = cu_num_tokens.cpu() - block_table = block_table.cpu() + num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 - target_positions = target_positions.cpu() + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + if self.name == SpecDcodeType.EAGLE3: assert isinstance(self.model, Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states( @@ -423,34 +324,7 @@ class EagleProposer(Proposer): # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids - seq_lens = (target_positions[last_token_indices] + 1).int() - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - attn_mask = self.runner.attn_mask - - common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=cu_num_tokens.to(device), - query_start_loc_cpu=cu_num_tokens, - seq_lens_cpu=seq_lens.cpu(), - max_query_len=max_query_len, - num_reqs=batch_size, - num_actual_tokens=num_tokens, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=self.runner.input_batch.block_table[0]. - get_device_tensor(), - slot_mapping=target_slot_mapping, - positions=target_positions, - attn_mask=attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - num_computed_tokens_cpu=None, - seq_lens=None) - # FIXME(woosuk): The below two ops cause synchronization. Optimize. - builder = self.runner.attn_groups[0][0].get_metadata_builder() - attn_metadata = builder.build(0, common_attn_metadata, - self.runner.get_model()) if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) @@ -458,9 +332,14 @@ class EagleProposer(Proposer): num_input_tokens = num_tokens # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions.to(device) + self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states - attn_metadata.block_tables = block_table.to(device) + + # FIXME(woosuk): The below two ops cause synchronization. Optimize. 
+ builder = self.runner.attn_groups[0][0].get_metadata_builder() + attn_metadata = builder.build(0, common_attn_metadata, + self.runner.get_model()) + with set_ascend_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens): @@ -482,12 +361,14 @@ class EagleProposer(Proposer): draft_token_ids_tensor = torch.zeros( (self.vllm_config.speculative_config.num_speculative_tokens, *draft_token_ids.shape), - dtype=draft_token_ids.dtype) + dtype=draft_token_ids.dtype, + device=self.device) draft_token_ids_tensor[0] = draft_token_ids - positions_cpu = target_positions[last_token_indices].cpu().to( - torch.int64) + positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] + last_token_indices = self.arange[:batch_size] + if self.use_cuda_graph and \ batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) @@ -496,16 +377,14 @@ class EagleProposer(Proposer): attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] + attn_metadata.query_start_loc = self.arange_cpu[:batch_size + 1] attn_metadata.query_start_loc_list = attn_metadata.query_start_loc[ 1:].tolist() attn_metadata.num_decodes, attn_metadata.num_prefills, attn_metadata.num_decode_tokens, attn_metadata.num_prefill_tokens = 0, batch_size, 0, batch_size attn_metadata.num_actual_tokens_pcp_padded = attn_metadata.num_decode_tokens + attn_metadata.num_prefill_tokens - query_lens.fill_(1) - attn_metadata.query_lens = query_lens attn_metadata.actual_seq_lengths_q = [1 + i for i in range(batch_size)] - attn_metadata.seq_lens_list = seq_lens.tolist() + attn_metadata.seq_lens_list = attn_metadata.seq_lens.tolist() attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill for now_speculative in range( self.vllm_config.speculative_config.num_speculative_tokens - @@ -513,8 +392,8 @@ class EagleProposer(Proposer): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. - input_ids = draft_token_ids_tensor[now_speculative].to(device) - positions_cpu += 1 + input_ids = draft_token_ids_tensor[now_speculative] + positions += 1 # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. Since it is complex @@ -522,16 +401,15 @@ class EagleProposer(Proposer): # but adjust the position ids and slot mappings to avoid the # out-of-range access during the model execution. The draft tokens # generated with this adjustment should be ignored. - exceeds_max_model_len = positions_cpu >= self.vllm_config.model_config.max_model_len + exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_positions_cpu = torch.where(exceeds_max_model_len, 0, - positions_cpu) - clamped_positions = clamped_positions_cpu.to(device) + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) # TODO: Increment the sequence lengths. - attn_metadata.seq_lens += 1 + attn_metadata.seq_lens = attn_metadata.seq_lens + 1 attn_metadata.seq_lens_list = [ _ + 1 for _ in attn_metadata.seq_lens_list ] @@ -542,22 +420,22 @@ class EagleProposer(Proposer): # TODO: sequence length to 1 to minimize their overheads in attention. # Compute the slot mapping. 
- block_numbers = (clamped_positions_cpu // self.block_size) - block_ids = block_table.gather(dim=1, - index=block_numbers.view(-1, 1)) + block_numbers = (clamped_positions // self.block_size) + block_ids = attn_metadata.block_tables.gather( + dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - slot_mapping_cpu = ( + slot_mapping_tmp = ( block_ids * self.vllm_config.cache_config.block_size + - clamped_positions_cpu % self.block_size) + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - slot_mapping_cpu.masked_fill_(exceeds_max_model_len, + slot_mapping_tmp.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) # NOTE: ASCEND slot_mapping must on cpu - attn_metadata.slot_mapping = slot_mapping_cpu.to( - torch.int32).to(device) + attn_metadata.slot_mapping[:slot_mapping_tmp.shape[0]].copy_( + slot_mapping_tmp.to(torch.int32)) # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions @@ -565,7 +443,6 @@ class EagleProposer(Proposer): attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask() attn_metadata.attn_mask = attn_mask - attn_metadata.block_tables = block_table.to(device) # Run the model. with set_ascend_forward_context(attn_metadata, self.vllm_config, @@ -581,49 +458,188 @@ class EagleProposer(Proposer): # TODO(wenlong): get more than one token for tree attention draft_token_ids = logits.argmax(dim=-1) - draft_token_ids_tensor[now_speculative + 1] = draft_token_ids.cpu() + draft_token_ids_tensor[now_speculative + 1] = draft_token_ids # [batch_size, num_speculative_tokens] draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) return draft_token_ids - def _prepare_inputs( + def _get_attn_metadata(self, attn_metadata): + if attn_metadata is not None and isinstance(attn_metadata, dict): + architecture = self.vllm_config.model_config.architecture + layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER) + attn_metadata = attn_metadata[layer_name] + + return attn_metadata + + def prepare_next_token_ids_cpu( self, - eagle_attn_metadata: AscendMetadata, - # [batch_size] - num_rejected_tokens: torch.Tensor, + sampled_token_ids: list[list[int]], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
+ req_id = req_ids[i] + req_state = requests[req_id] + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[ + req_id] + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.input_ids.device) + return next_token_ids + + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ - This function is used to prepare the inputs for the spec decode. + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array([ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item()) + for i in range(num_reqs) + ]) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. + discard_sampled_tokens_req_indices = discard_request_indices[: + num_discarded_requests] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1) + + # Generate a mask for all valid tokens within those requests + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices_safe.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) + + return next_token_ids, valid_sampled_tokens_count + + def _copy_valid_sampled_token_count( + self, next_token_ids: torch.Tensor, + valid_sampled_tokens_count: torch.Tensor) -> None: + if self.runner.valid_sampled_token_count_event is not None: + default_stream = torch.npu.current_stream() + # initialize a new stream to overlap the copy operation with + # prepare_input of draft model. 
+ with torch.npu.stream( + self.runner.valid_sampled_token_count_copy_stream): + self.runner.valid_sampled_token_count_copy_stream.wait_stream( + default_stream) # type: ignore + self.runner.valid_sampled_token_count_cpu[: + valid_sampled_tokens_count + .shape[0]].copy_( + valid_sampled_tokens_count, + non_blocking=True + ) + self.runner.valid_sampled_token_count_event.record() + + self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze( + 1) + + def prepare_inputs( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. It updates to the common_attn_metadata to account for the rejected tokens (and newly sampled tokens). It also returns the token indices of the tokens that should be fed to the speculator. """ # E.g. # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] + # [0, q1, q1 + q2, q1 + q2 + q3] # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] # num_rejected_tokens: [n1, n2, n3] # This function computes the intermediate values: # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] # And returns: # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] - num_rejected_tokens_cpu = num_rejected_tokens.to("cpu") - cu_target_query_lens = eagle_attn_metadata.query_start_loc - device = eagle_attn_metadata.query_start_loc.device - query_start_loc_cpu = cu_target_query_lens.to("cpu") + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + + num_actual_reqs = len(num_draft_tokens) + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: + num_actual_reqs + + 1] + seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs] + new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) + new_query_len_per_req = query_start_loc_cpu[ + 1:] - query_start_loc_cpu[:-1] # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] - new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() # [q1 - n1, q2 - n2, q3 - n3] -> @@ -631,7 +647,8 @@ class EagleProposer(Proposer): new_query_start_loc_cpu = torch.zeros( query_start_loc_cpu.shape, dtype=torch.int32, - pin_memory=is_pin_memory_available()) + pin_memory=is_pin_memory_available(), + ) new_query_start_loc_np = new_query_start_loc_cpu.numpy() np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) @@ -646,8 +663,8 @@ class EagleProposer(Proposer): # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> # [0, 1, 0, 1, 2, 3, 0, 1, 2] # _r1_ ____r2____ ___r3__ - token_offests = 
self.token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded + token_offests = (self.token_arange_np[:total_num_tokens] - + new_query_start_locs_expanded) # Expand starting positions to match token pattern # [0, q1, q1 + q2] -> @@ -656,21 +673,101 @@ class EagleProposer(Proposer): old_query_start_locs_expanded = np.repeat( query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded token_indices = torch.from_numpy(token_indices_np).to( device, non_blocking=True) - # need use npu - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - num_tokens_per_req = query_len_per_req - num_rejected_tokens + common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_( + common_attn_metadata.slot_mapping[token_indices]) + common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1) - # [a - n1, b - n2, c - n3] -> - # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - cu_num_tokens = torch.zeros_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) + # NOTE: Currently positions and seq_lens are not used in mla_v1 forward, + # so we do not need to fix them. But if they are used in the future, + # we should fix them. + spec_common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + num_input_tokens=common_attn_metadata.num_input_tokens, + max_query_len=new_query_len_per_req.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + positions=common_attn_metadata.positions[token_indices], + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + ) + return spec_common_attn_metadata, token_indices - return cu_num_tokens, token_indices + def prepare_inputs_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor, + ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function.
+ """ + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1], + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu), + ) + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + new_query_len_per_req = query_start_loc_cpu[ + 1:] - query_start_loc_cpu[:-1] + + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] + + # NOTE: Currently positions and seq_lens are not used in mla_v1 forward + # so we do not need to fixed them. But if they are used in the future, + # we should fixed them. + spec_common_attn_metadata = AscendCommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + num_input_tokens=common_attn_metadata.num_input_tokens, + max_query_len=new_query_len_per_req.max().item(), + actual_seq_lengths_q=self.runner.actual_seq_lengths_q, + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping, + positions=common_attn_metadata.positions, + attn_mask=self.runner.attn_mask, + spec_attn_mask=self.runner.spec_attn_mask, + attn_state=self.runner.attn_state, + decode_token_per_req=self.runner.decode_token_per_req, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + seq_lens=common_attn_metadata.seq_lens) + + token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] - + 1 - num_rejected_tokens_gpu) + + return spec_common_attn_metadata, token_indices, token_indices_to_sample diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f8de7729..d4b4b25b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -801,7 +801,8 @@ class NPUModelRunner(GPUModelRunner): self.requests[r].num_tokens for r in self.input_batch.req_ids ] num_tokens_np = np.array(num_tokens, dtype=np.int32) - num_reqs = self.input_batch.num_reqs + base_num_reqs = self.input_batch.num_reqs + num_reqs = base_num_reqs if self.pcp_size > 1: # while pcp > 1, we need the original num_scheduled_tokens before split # to calculate discard_requests_mask @@ -1106,6 +1107,11 @@ class NPUModelRunner(GPUModelRunner): if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: self.spec_decode_common_attn_metadata = common_attn_metadata + if self.speculative_config.method in ("eagle", "eagle3") and \ + self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(): + self.spec_decode_common_attn_metadata = \ + self.spec_decode_common_attn_metadata.unpadded( + total_num_scheduled_tokens, base_num_reqs) for attn_group in self.attn_groups[kv_cache_group_id]: common_prefix_len = 0 @@ -1591,7 +1597,7 @@ class NPUModelRunner(GPUModelRunner): with ProfileExecuteDuration().capture_async("Draft"): if self.speculative_config: use_padded_batch_for_eagle = self.speculative_config and \ - self.speculative_config.method == "mtp" and \ + self.speculative_config.use_eagle() and \ not self.speculative_config.disable_padded_drafter_batch if use_padded_batch_for_eagle: # EAGLE speculative decoding can use the GPU sampled tokens