diff --git a/docs/source/developer_guide/feature_guide/index.md b/docs/source/developer_guide/feature_guide/index.md index 6ceb74cd..9c92246c 100644 --- a/docs/source/developer_guide/feature_guide/index.md +++ b/docs/source/developer_guide/feature_guide/index.md @@ -9,7 +9,6 @@ patch ModelRunner_prepare_inputs disaggregated_prefill eplb_swift_balancer.md -Multi_Token_Prediction ACL_Graph KV_Cache_Pool_Guide add_custom_aclnn_op diff --git a/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md b/docs/source/user_guide/feature_guide/Multi_Token_Prediction.md similarity index 100% rename from docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md rename to docs/source/user_guide/feature_guide/Multi_Token_Prediction.md diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md index 0a763eae..bc452aaa 100644 --- a/docs/source/user_guide/feature_guide/index.md +++ b/docs/source/user_guide/feature_guide/index.md @@ -12,6 +12,7 @@ structured_output lora eplb_swift_balancer netloader +Multi_Token_Prediction dynamic_batch kv_pool external_dp diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index 0c85d9d7..2ecd7db1 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -144,9 +144,17 @@ class TestEagleProposerLoadModel(TestBase): def test_load_model_pp1(self, mock_pp_group, mock_get_model, mock_get_layers): mock_pp_group.return_value.world_size = 1 - mock_target_layers = {"layer1": MagicMock(), "layer2": MagicMock()} - mock_draft_layers = {"layer1": MagicMock(), "layer3": MagicMock()} - mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers] + mock_target_layer1 = MagicMock() + mock_target_layer2 = MagicMock() + mock_draft_layer1 = MagicMock() + mock_draft_layer3 = MagicMock() + mock_get_layers.side_effect = [{ + "layer1": mock_target_layer1, + "layer2": mock_target_layer2 + }, {}, {}, { + "layer1": mock_draft_layer1, + "layer3": mock_draft_layer3 + }] mock_model = MagicMock() mock_model.model.embed_tokens = MagicMock() @@ -158,7 +166,7 @@ class TestEagleProposerLoadModel(TestBase): self.proposer.load_model(mock_model) mock_get_model.assert_called_once() - self.assertEqual(self.proposer.attn_layer_name, "layer3") + self.assertEqual(self.proposer.attn_layer_name, ["layer3"]) self.assertIs(self.proposer.model.model.embed_tokens, mock_model.model.embed_tokens) @@ -169,9 +177,14 @@ class TestEagleProposerLoadModel(TestBase): def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model, mock_get_layers): mock_pp_group.return_value.world_size = 2 - mock_target_layers = {"layer1": MagicMock()} - mock_draft_layers = {"layer2": MagicMock()} - mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers] + mock_target_layer1 = MagicMock() + mock_draft_layer2 = MagicMock() + + mock_get_layers.side_effect = [{ + "layer1": mock_target_layer1 + }, {}, {}, { + "layer2": mock_draft_layer2 + }] mock_model = MagicMock() original_embed = MagicMock() @@ -184,7 +197,7 @@ class TestEagleProposerLoadModel(TestBase): self.assertIsNot(self.proposer.model.model.embed_tokens, mock_model.model.embed_tokens) - self.assertEqual(self.proposer.attn_layer_name, "layer2") + self.assertEqual(self.proposer.attn_layer_name, ["layer2"]) @patch( "vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config") @@ -200,9 +213,14 @@ class TestEagleProposerLoadModel(TestBase): mock_get_model.return_value = MagicMock(model=MagicMock( 
embed_tokens=original_embed)) - mock_target_layers = {"layer1": MagicMock()} - mock_draft_layers = {"layer2": MagicMock()} - mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers] + mock_target_layer1 = MagicMock() + mock_draft_layer2 = MagicMock() + + mock_get_layers.side_effect = [{ + "layer1": mock_target_layer1 + }, {}, {}, { + "layer2": mock_draft_layer2 + }] mock_pp_group.return_value.world_size = 2 self.proposer.model = MagicMock() @@ -307,83 +325,6 @@ class TestEagleProposerDummyRun(TestBase): self.proposer.use_cuda_graph = last_use_cuda_graph -class TestEagleProposerGenerateTokenIds(TestBase): - - def setUp(self): - self.vllm_config = MagicMock(spec=VllmConfig) - self.vllm_config.speculative_config = MagicMock() - self.vllm_config.speculative_config.method = "eagle" - self.device = torch.device("cpu") - self.runner = MagicMock() - self.runner.input_batch = MagicMock() - self.runner.input_batch.req_ids = [0, 1, 2] - self.runner.requests = { - 0: MagicMock(get_token_id=lambda x: 100), - 1: MagicMock(get_token_id=lambda x: 101), - 2: MagicMock(get_token_id=lambda x: 102), - } - self.runner.pcp_size = 1 - - self.vllm_config.cache_config.block_size = 16 - self.vllm_config.scheduler_config.max_num_batched_tokens = 1024 - self.vllm_config.scheduler_config.max_num_seqs = 32 - self.vllm_config.model_config.dtype = torch.float16 - self.vllm_config.model_config.max_model_len = 2048 - self.vllm_config.model_config.uses_mrope = False - self.vllm_config.speculative_config.num_speculative_tokens = 2 - self.vllm_config.speculative_config.speculative_token_tree = str([ - (i + 1) * (0, ) for i in range(2) - ]) - self.vllm_config.additional_config = None - init_ascend_config(self.vllm_config) - - self.mock_cpugpubuffer = patch( - "vllm.v1.spec_decode.eagle.CpuGpuBuffer") - self.mock_cpugpubuffer.start() - self.mock_supports_multimodal_inputs = patch( - "vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs" - ) - self.mock_supports_multimodal_inputs.start() - self.proposer = EagleProposer(vllm_config=self.vllm_config, - device=self.device, - runner=self.runner) - self.proposer.attn_layer_name = "layer_0" - self.proposer._propose = MagicMock( - return_value=torch.tensor([[1, 2], [3, 4], [5, 6]])) - - def tearDown(self): - self.mock_cpugpubuffer.stop() - self.mock_supports_multimodal_inputs.stop() - - # TODO: This is equivalent to disable_padded_drafter_batch=True. - # We need to add some cases about disable_padded_drafter_batch=False in future. - def test_generate_token_ids(self): - valid_sampled = [[20, 30, 40]] - scheduler_output = MagicMock() - scheduler_output.num_scheduled_tokens = [2, 1, 3] - positions = torch.tensor([0, 1, 2, 3, 4, 5]) - hidden_states = torch.randn(6, 4096) - num_scheduled = 6 - - mock_attn_metadata = MagicMock() - mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5]) - mock_attn_metadata.query_start_loc = torch.tensor([0, 2, 3, 6]) - mock_attn_metadata.block_tables = MagicMock() - self.proposer._get_eagle_atten_dict = MagicMock( - return_value={"layer_0": mock_attn_metadata}) - - result = self.proposer.generate_token_ids( - sampled_token_ids=valid_sampled, - scheduler_output=scheduler_output, - positions=positions, - num_scheduled_tokens=num_scheduled, - hidden_states=hidden_states, - ) - - self.proposer._propose.assert_called_once() - self.assertEqual(result.numpy().tolist(), [[1, 2], [3, 4], [5, 6]]) - - class TestEagleProposerHelperMethods(TestBase): # TODO: Can add some tests about prepare_next_token_ids in future. 
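Note: the four dicts fed to mock_get_layers.side_effect above mirror the four get_layers_from_vllm_config calls made by the reworked EagleProposer.load_model (target AttentionLayerBase, target DeepseekV32IndexerCache, draft DeepseekV32IndexerCache, draft AttentionLayerBase), and attn_layer_name is now asserted as a list. A minimal, illustrative sketch of the resulting set arithmetic (not part of the patch), reusing the layer names from the pp=1 test above:

    target_attn = {"layer1", "layer2"}   # 1st call: target model, AttentionLayerBase
    target_indexer = set()               # 2nd call: target model, DeepseekV32IndexerCache
    all_indexer = set()                  # 3rd call: after draft load, DeepseekV32IndexerCache
    all_attn = {"layer1", "layer3"}      # 4th call: after draft load, AttentionLayerBase

    draft_attn = all_attn - target_attn            # layers introduced by the draft model
    draft_indexer = all_indexer - target_indexer   # indexer layers are excluded, since they
    draft_attn = draft_attn - draft_indexer        # have no attention backend/metadata yet

    assert len(draft_attn) == 1
    attn_layer_name = list(draft_attn)             # ["layer3"], matching the updated assertion
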
diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index 163bcfb8..c3d62dc5 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -6,12 +6,8 @@ import torch from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode, ModelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_config import init_ascend_config @@ -107,53 +103,6 @@ class TestMtpProposer: assert proposer.use_aclgraph is True - @patch("vllm.config.get_layers_from_vllm_config") - @patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader") - @patch( - "vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading") - @patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype") - @patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config") - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config, - mock_set_dtype, mock_process_weights, mock_get_loader, - mock_get_layers, vllm_config, runner): - mock_buffer_instance = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_buffer_instance - attn_layers_all = { - "target_attn_layer": "val0", - "draft_attn_layer": "val1", - "draft_attn_exclude_by_indexer": "val2", - } - - indexer_layers_all = { - "target_indexer_0": "val3", - "draft_attn_exclude_by_indexer": "val4" - } - - def get_layers_side_effect(vllm_config, cache_cls): - if cache_cls == AttentionLayerBase: - return attn_layers_all - elif cache_cls == DeepseekV32IndexerCache: - return indexer_layers_all - else: - return {} - - # Setup - proposer = MtpProposer(vllm_config, torch.device("cpu"), runner) - proposer._init_mtp_model = MagicMock() - mock_model = MagicMock() - proposer.model = mock_model - - mock_loader = MagicMock() - mock_get_loader.return_value = mock_loader - mock_loader.get_all_weights.return_value = { - "dummy_weight": torch.tensor([1.0]) - } - - mock_get_layers.side_effect = get_layers_side_effect - with pytest.raises(AssertionError): - proposer.load_model(mock_model) - @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context") @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context") @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") @@ -209,78 +158,6 @@ class TestMtpProposer: # Check that model was called correct number of times assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") - def test_generate_token_ids(self, mock_cpu_gpu_buffer): - mock_buffer_instance = MagicMock() - mock_cpu_gpu_buffer.return_value = mock_buffer_instance - - mock_deps = MagicMock() - mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput) - mock_deps.scheduler_output.num_scheduled_tokens = 16 - mock_deps.spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata) - mock_deps.spec_decode_metadata.num_draft_tokens = 2 - mock_deps.runner = MagicMock() - mock_deps.runner.input_batch = MagicMock(num_reqs=4) - mock_deps.runner.input_ids = torch.arange(16, 
dtype=torch.int32) - mock_deps.runner.spec_decode_common_attn_metadata = MagicMock() - mock_deps.runner.pcp_size = 2 - mock_deps.runner.dcp_size = 1 - mock_deps.runner.pcp_manager = MagicMock() - mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer( - 32, - dtype=torch.int32, - pin_memory=False, - device='cpu', - ) - mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \ - torch.arange(32, dtype=torch.int32) - mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer( - 5, - dtype=torch.int32, - pin_memory=False, - device='cpu', - ) - mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \ - torch.tensor([0, 8, 16, 24, 32]) - mock_deps.positions = torch.arange(16, dtype=torch.int32) - mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16) - mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1], - [200, -1, -1], - [300, 301, 302]]) - - proposer = MagicMock(spec=MtpProposer) - proposer.enable_shared_expert_dp = False - proposer.runner = mock_deps.runner - proposer.decode_threshold = 1 - proposer.speculative_config = MagicMock( - disable_padded_drafter_batch=False) - proposer.pcp_size = mock_deps.runner.pcp_size - proposer.dcp_size = mock_deps.runner.dcp_size - proposer.prepare_next_token_ids_padded = MagicMock( - return_value=(torch.tensor([101, 200, 302]), 3)) - proposer.prepare_inputs_padded = MagicMock( - return_value=(MagicMock(), torch.tensor([0, 2, 4]), - torch.tensor([7, 15, 23]))) - proposer._propose = MagicMock( - return_value=torch.tensor([400, 401, 402])) - proposer.generate_token_ids = MtpProposer.generate_token_ids.__get__( - proposer) - - draft_token_ids = proposer.generate_token_ids( - sampled_token_ids=mock_deps.sampled_token_ids, - scheduler_output=mock_deps.scheduler_output, - spec_decode_metadata=mock_deps.spec_decode_metadata, - positions=mock_deps.positions, - num_scheduled_tokens=mock_deps.scheduler_output. 
- num_scheduled_tokens, - hidden_states=mock_deps.hidden_states, - ) - - proposer.prepare_next_token_ids_padded.assert_called_once() - proposer.prepare_inputs_padded.assert_called_once() - proposer._propose.assert_called_once() - assert torch.equal(draft_token_ids, proposer._propose.return_value) - @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer): mock_buffer_instance = MagicMock() diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 30c6396d..d1ce1edf 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -4,7 +4,6 @@ from typing import Optional import numpy as np import torch import torch.nn as nn -from vllm.attention.layer import Attention from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config) from vllm.distributed.parallel_state import get_pp_group @@ -13,6 +12,7 @@ from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -109,25 +109,54 @@ class EagleProposer(VllmEagleProposer): def load_model(self, model: nn.Module) -> None: target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase).keys()) + target_indexer_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, + DeepseekV32IndexerCache).keys()) + self.model = get_model(vllm_config=self.vllm_config, model_config=self.vllm_config. speculative_config.draft_model_config) - draft_attn_layer_names = (get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase).keys() - - target_attn_layer_names) - self.attn_layer_name = next(iter(draft_attn_layer_names)) + + indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache).keys() + draft_attn_layer = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase).keys() + + draft_attn_layer_names = draft_attn_layer - target_attn_layer_names + draft_indexer_layer_names = indexer_layers - target_indexer_layer_names + draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = list(draft_attn_layer_names) # share embed_tokens with the target model if needed if get_pp_group().world_size == 1: - logger.info( - "The EAGLE head shares the same vocab embedding" \ - " with the target model." - ) - self.model.model.embed_tokens = model.model.embed_tokens + if self.method == "mtp": + if self.vllm_config.model_config.is_deepseek_mla and \ + torch.equal(self.model.model.embed_tokens.weight, + model.model.embed_tokens.weight): + # If pp>1, the weights of mtp and the main model's embedding are not on the same device. + # check if mtp model use main model's embedding and LMhead + logger.info( + "The MTP head shares the same vocab embedding" \ + " with the target model." 
+                    )
+                    self.model.model.embed_tokens = model.model.embed_tokens
+                else:
+                    logger.info(
+                        "The MTP head loaded its own vocab embedding" \
+                        " weights instead of sharing them with the target model."
+                    )
+            else:
+                logger.info(
+                    "The EAGLE head shares the same vocab embedding" \
+                    " with the target model."
+                )
+                self.model.model.embed_tokens = model.model.embed_tokens
         else:
             logger.info(
-                "Since PP > 1, the EAGLE head loaded its own vocab embedding" \
+                "Since PP > 1 or for other reasons, the model head loaded its own vocab embedding" \
                 " weights instead of sharing them with the target model."
             )
@@ -141,6 +170,13 @@ class EagleProposer(VllmEagleProposer):
         else:
             self.model.lm_head = model.lm_head
 
+        if self.method == "mtp" and \
+            self.vllm_config.model_config.is_deepseek_mla:
+            for _, layer_module in self.model.model.layers.items():
+                if torch.equal(layer_module.shared_head.head.weight,
+                               model.lm_head.weight):
+                    layer_module.shared_head.head = model.lm_head
+
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ) and self.use_cuda_graph:
             self.update_stream = torch.npu.Stream()
@@ -205,7 +241,7 @@ class EagleProposer(VllmEagleProposer):
             attn_metadata_eagle = builder.build_for_graph_capture(
                 common_attn_metadata, AscendAttentionState.ChunkedPrefill)
             attn_metadata = {}
-            for layer_name in [self.attn_layer_name]:
+            for layer_name in self.attn_layer_name:
                 attn_metadata[layer_name] = attn_metadata_eagle
             for i in range(self.num_speculative_tokens):
                 if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
@@ -235,135 +271,6 @@ class EagleProposer(VllmEagleProposer):
                 self.vllm_config,
             )
 
-    def generate_token_ids(self,
-                           sampled_token_ids: torch.Tensor | list[list[int]],
-                           sampling_metadata: SamplingMetadata = None,
-                           scheduler_output: SchedulerOutput = None,
-                           spec_decode_metadata: SpecDecodeMetadata = None,
-                           positions: torch.Tensor = None,
-                           num_scheduled_tokens: int = 0,
-                           hidden_states: torch.Tensor = None,
-                           aux_hidden_states: torch.Tensor = None):
-        common_attn_metadata = self.runner.spec_decode_common_attn_metadata
-
-        if self.vllm_config.speculative_config.disable_padded_drafter_batch:
-            # When padded-batch is disabled, the sampled_token_ids should be
-            # the cpu-side list[list[int]] of valid sampled tokens for each
-            # request, with invalid requests having empty lists.
-            assert isinstance(sampled_token_ids, list), \
-                "sampled_token_ids should be a python list when" \
-                "padded-batch is disabled."
-            next_token_ids = self.prepare_next_token_ids_cpu(
-                sampled_token_ids, self.runner.requests,
-                self.runner.input_batch, scheduler_output.num_scheduled_tokens)
-        else:
-            # When using padded-batch, the sampled_token_ids should be
-            # the gpu tensor of sampled tokens for each request, of shape
-            # (num_reqs, num_spec_tokens + 1) with rejected tokens having
-            # value -1.
-            assert isinstance(sampled_token_ids, torch.Tensor), \
-                "sampled_token_ids should be a torch.Tensor when" \
-                "padded-batch is enabled."
- next_token_ids, valid_sampled_tokens_count = \ - self.prepare_next_token_ids_padded( - common_attn_metadata, - sampled_token_ids, - self.runner.requests, - self.runner.input_batch, - self.runner.discard_request_indices.gpu, - self.runner.num_discarded_requests - ) - self._copy_valid_sampled_token_count(next_token_ids, - valid_sampled_tokens_count) - - req_scheduled_tokens = scheduler_output.num_scheduled_tokens - if self.pcp_size > 1: - long_seq_metadata = self.runner.long_seq_metadata - input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu - query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu - query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu - num_reqs = self.runner.input_batch.num_reqs - ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ - query_start_loc_pcp_full_cpu[:num_reqs] - num_prefill_reqs = (ori_query_lens - > self.decode_threshold).sum().item() - num_decode_reqs = num_reqs - num_prefill_reqs - else: - long_seq_metadata = None - num_prefill_reqs = 0 - num_decode_reqs = 0 - if spec_decode_metadata is None: - # update pcp related params - if self.pcp_size > 1: - token_indices_to_sample = \ - query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1 - target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states - else: - token_indices_to_sample = None - # input_ids can be None for multimodal models. - target_token_ids = self.runner.input_ids.gpu[: - num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - if self.method == "eagle3": - target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) - else: - target_hidden_states = hidden_states[:num_scheduled_tokens] - else: - if self.pcp_size > 1: - common_attn_metadata.query_start_loc_cpu = \ - query_start_loc_pcp_full_cpu[:num_reqs + 1] - common_attn_metadata.query_start_loc = \ - query_start_loc_pcp_full[:num_reqs + 1] - if self.vllm_config.speculative_config.disable_padded_drafter_batch: - # NOTE: Currently, MTP-fullgraph is incompatibility with pcp - token_indices_to_sample = None - common_attn_metadata, token_indices =\ - self.prepare_inputs( - common_attn_metadata, - sampled_token_ids, - spec_decode_metadata.num_draft_tokens) - else: - common_attn_metadata, token_indices, \ - token_indices_to_sample =\ - self.prepare_inputs_padded( - common_attn_metadata, - spec_decode_metadata, - valid_sampled_tokens_count) - if self.pcp_size > 1: - target_token_ids = input_ids_pcp_full[token_indices] - target_positions = positions - target_hidden_states = hidden_states - else: - target_token_ids = self.runner.input_ids.gpu[token_indices] - target_positions = positions[token_indices] - if self.method == "eagle3": - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) - else: - target_hidden_states = hidden_states[token_indices] - - draft_token_ids = self._propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata, - req_scheduled_tokens=req_scheduled_tokens, - long_seq_metadata=long_seq_metadata, - num_prefill_reqs=num_prefill_reqs, - num_decode_reqs=num_decode_reqs, - scheduler_output=scheduler_output, - 
num_scheduled_tokens=num_scheduled_tokens, - ) - - return draft_token_ids - def _propose( self, # [num_tokens] @@ -430,9 +337,11 @@ class EagleProposer(VllmEagleProposer): self.runner.get_model()) # update global cos, sin update_cos_sin(self.positions[:num_input_tokens]) - + per_layer_attn_metadata = {} + for layer_name in self.attn_layer_name: + per_layer_attn_metadata[layer_name] = attn_metadata with set_ascend_forward_context( - {self.attn_layer_name: attn_metadata}, + per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens, num_actual_tokens=num_tokens, @@ -558,7 +467,7 @@ class EagleProposer(VllmEagleProposer): # Run the model. with set_ascend_forward_context( - {self.attn_layer_name: attn_metadata}, + per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size, num_actual_tokens=batch_size, @@ -696,28 +605,6 @@ class EagleProposer(VllmEagleProposer): return next_token_ids, valid_sampled_tokens_count - def _copy_valid_sampled_token_count( - self, next_token_ids: torch.Tensor, - valid_sampled_tokens_count: torch.Tensor) -> None: - if self.runner.valid_sampled_token_count_event is not None: - default_stream = torch.npu.current_stream() - # initialize a new stream to overlap the copy operation with - # prepare_input of draft model. - with torch.npu.stream( - self.runner.valid_sampled_token_count_copy_stream): - self.runner.valid_sampled_token_count_copy_stream.wait_stream( - default_stream) # type: ignore - self.runner.valid_sampled_token_count_cpu[: - valid_sampled_tokens_count - .shape[0]].copy_( - valid_sampled_tokens_count, - non_blocking=True - ) - self.runner.valid_sampled_token_count_event.record() - - self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze( - 1) - def prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 08990671..21533d43 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,25 +1,16 @@ -import importlib from typing import Optional, Union import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config, - set_current_vllm_config) +from vllm.config import CUDAGraphMode from vllm.distributed import get_pcp_group -from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.utils import \ - process_weights_after_loading -from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import is_pin_memory_available -from vllm.utils.torch_utils import set_default_torch_dtype from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -54,15 +45,6 @@ _MTP_MODELS = { } -def _load_model(architecture): - if architecture not in _MTP_MODELS: - raise ValueError("Invalid architecture for mtp.") - module_name, model_name = _MTP_MODELS[architecture] - module = importlib.import_module(module_name) - model = getattr(module, model_name) - return model - 
- class MtpProposer(EagleProposer): # TODO: Find out why ModelRunner does not this explicit typing? @@ -86,64 +68,6 @@ class MtpProposer(EagleProposer): update_attn_params(self.update_stream, forward_context, num_tokens, self.vllm_config) - def load_model(self, model) -> None: - loader = get_model_loader(self.vllm_config.load_config) - - target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase).keys()) - target_indexer_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache).keys()) - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config - target_device = self.vllm_config.device_config.device - - with set_default_torch_dtype( - draft_model_config.dtype), set_current_vllm_config( - self.vllm_config): - self._init_mtp_model() - draft_attn_layer_names = (get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase).keys() - - target_attn_layer_names) - indexer_layers = get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache) - draft_indexer_layer_names = indexer_layers.keys( - ) - target_indexer_layer_names - # NOTE: Currently we don't have specific attention backend and attention metadata - # for deepseek v3.2 indexer, so we just exclude the indexer layers here. - draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names - - assert len(draft_attn_layer_names) == 1 - self.attn_layer_name = list(draft_attn_layer_names) - - self.model.load_weights( - loader.get_all_weights( - self.vllm_config.speculative_config.draft_model_config, - self.model)) - process_weights_after_loading(self.model, draft_model_config, - target_device) - - if self.vllm_config.model_config.is_deepseek_mla: - # check if mtp model use main model's embedding and LMhead - main_model = model - if get_pp_group().world_size == 1: - # If pp>1, the weights of mtp and the main model's embedding are not on the same device. - if torch.equal(self.model.model.embed_tokens.weight, - main_model.model.embed_tokens.weight): - self.model.model.embed_tokens = main_model.model.embed_tokens - for _, layer_module in self.model.model.layers.items(): - if torch.equal(layer_module.shared_head.head.weight, - main_model.lm_head.weight): - layer_module.shared_head.head = main_model.lm_head - - if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( - ): - self.update_stream: torch.npu.Stream = torch.npu.Stream() - self.model = ACLGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) - @torch.inference_mode() def dummy_run(self, num_tokens: int, @@ -256,153 +180,6 @@ class MtpProposer(EagleProposer): if with_prefill: break - def generate_token_ids(self, - sampled_token_ids: torch.Tensor | list[list[int]], - sampling_metadata: SamplingMetadata = None, - scheduler_output: SchedulerOutput = None, - spec_decode_metadata: SpecDecodeMetadata = None, - positions: torch.Tensor = None, - num_scheduled_tokens: int = 0, - hidden_states: torch.Tensor = None, - aux_hidden_states: torch.Tensor = None): - common_attn_metadata = self.runner.spec_decode_common_attn_metadata - - if self.speculative_config.disable_padded_drafter_batch: - # When padded-batch is disabled, the sampled_token_ids should be - # the cpu-side list[list[int]] of valid sampled tokens for each - # request, with invalid requests having empty lists. - assert isinstance(sampled_token_ids, list), \ - "sampled_token_ids should be a python list when" \ - "padded-batch is disabled." 
- next_token_ids = self.prepare_next_token_ids_cpu( - sampled_token_ids, self.runner.requests, - self.runner.input_batch, scheduler_output.num_scheduled_tokens) - else: - # When using padded-batch, the sampled_token_ids should be - # the gpu tensor of sampled tokens for each request, of shape - # (num_reqs, num_spec_tokens + 1) with rejected tokens having - # value -1. - assert isinstance(sampled_token_ids, torch.Tensor), \ - "sampled_token_ids should be a torch.Tensor when" \ - "padded-batch is enabled." - next_token_ids, valid_sampled_tokens_count = \ - self.prepare_next_token_ids_padded( - common_attn_metadata, - sampled_token_ids, - self.runner.requests, - self.runner.input_batch, - self.runner.discard_request_indices.gpu, - self.runner.num_discarded_requests - ) - self._copy_valid_sampled_token_count(next_token_ids, - valid_sampled_tokens_count) - - req_scheduled_tokens = scheduler_output.num_scheduled_tokens - if self.pcp_size * self.dcp_size > 1: - long_seq_metadata = self.runner.long_seq_metadata - input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu - query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu - query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu - num_reqs = self.runner.input_batch.num_reqs - ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ - query_start_loc_pcp_full_cpu[:num_reqs] - num_prefill_reqs = (ori_query_lens - > self.decode_threshold).sum().item() - num_decode_reqs = num_reqs - num_prefill_reqs - else: - long_seq_metadata = None - num_prefill_reqs = 0 - num_decode_reqs = 0 - if spec_decode_metadata is None: - # update pcp related params - if self.pcp_size > 1: - token_indices_to_sample = \ - query_start_loc_pcp_full[1:num_reqs + 1] - 1 - target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states - else: - token_indices_to_sample = None - # input_ids can be None for multimodal models. 
- target_token_ids = self.runner.input_ids.gpu[: - num_scheduled_tokens] - target_positions = positions[:num_scheduled_tokens] - target_hidden_states = hidden_states[:num_scheduled_tokens] - else: - if self.pcp_size > 1: - common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \ - query_start_loc_pcp_full_cpu[:num_reqs + 1] - common_attn_metadata.query_start_loc[:num_reqs + 1] = \ - query_start_loc_pcp_full[:num_reqs + 1] - if self.speculative_config.disable_padded_drafter_batch: - token_indices_to_sample = None - common_attn_metadata, token_indices =\ - self._prepare_inputs( - common_attn_metadata, - sampled_token_ids, - spec_decode_metadata.num_draft_tokens) - else: - common_attn_metadata, token_indices, \ - token_indices_to_sample =\ - self.prepare_inputs_padded( - common_attn_metadata, - spec_decode_metadata, - valid_sampled_tokens_count) - if self.pcp_size > 1: - target_token_ids = input_ids_pcp_full[token_indices] - target_positions = positions - target_hidden_states = hidden_states - else: - target_token_ids = self.runner.input_ids.gpu[token_indices] - target_positions = positions[token_indices] - target_hidden_states = hidden_states[token_indices] - - draft_token_ids = self._propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata, - req_scheduled_tokens=req_scheduled_tokens, - long_seq_metadata=long_seq_metadata, - num_prefill_reqs=num_prefill_reqs, - num_decode_reqs=num_decode_reqs, - scheduler_output=scheduler_output, - num_scheduled_tokens=num_scheduled_tokens, - ) - - return draft_token_ids - - def _copy_valid_sampled_token_count( - self, next_token_ids: torch.Tensor, - valid_sampled_tokens_count: torch.Tensor) -> None: - if self.runner.valid_sampled_token_count_event is not None: - default_stream = torch.npu.current_stream() - # initialize a new stream to overlap the copy operation with - # prepare_input of draft model. - with torch.npu.stream( - self.runner.valid_sampled_token_count_copy_stream): - self.runner.valid_sampled_token_count_copy_stream.wait_stream( - default_stream) # type: ignore - self.runner.valid_sampled_token_count_cpu[: - valid_sampled_tokens_count - .shape[0]].copy_( - valid_sampled_tokens_count, - non_blocking=True - ) - self.runner.valid_sampled_token_count_event.record() - - self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze( - 1) - - def _init_mtp_model(self): - architecture = self.vllm_config.model_config.architecture - target_device = self.vllm_config.device_config.device - model = _load_model(architecture) - self.model = model(vllm_config=self.vllm_config).to(target_device) - def _prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d62f76dc..8017d63a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1,6 +1,6 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. +# Copyright 2025 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -54,6 +54,7 @@ from vllm.utils.math_utils import cdiv
 from vllm.utils.mem_utils import DeviceMemoryProfiler
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         EncoderOnlyAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
@@ -113,7 +114,6 @@ from vllm_ascend.worker.pcp_utils import PCPManager
 from vllm_ascend.ascend_forward_context import (  # isort: skip
     MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
     set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)
-
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -1257,6 +1257,7 @@ class NPUModelRunner(GPUModelRunner):
             logits_indices=logits_indices,
         )
 
+    # TODO: Once the PCP features are complete, this will fully inherit from the vLLM community classes.
     def propose_draft_token_ids(
         self,
         valid_sampled_token_ids: torch.Tensor | list[list[int]],
@@ -1273,10 +1274,147 @@ class NPUModelRunner(GPUModelRunner):
             # Speculative decoding is not enabled.
             draft_token_ids = None
         else:
-            draft_token_ids = self.drafter.generate_token_ids(
-                valid_sampled_token_ids, sampling_metadata, scheduler_output,
-                spec_decode_metadata, positions, num_scheduled_tokens,
-                hidden_states, aux_hidden_states)
+            if self.speculative_config.method in ("suffix", "ngram"):
+                draft_token_ids = self.drafter.generate_token_ids(
+                    valid_sampled_token_ids, sampling_metadata,
+                    scheduler_output, spec_decode_metadata, positions,
+                    num_scheduled_tokens, hidden_states, aux_hidden_states)
+
+            elif self.speculative_config.use_eagle():
+                common_attn_metadata = self.spec_decode_common_attn_metadata
+                sampled_token_ids = valid_sampled_token_ids
+
+                if self.vllm_config.speculative_config.disable_padded_drafter_batch:
+                    # When padded-batch is disabled, the sampled_token_ids should be
+                    # the cpu-side list[list[int]] of valid sampled tokens for each
+                    # request, with invalid requests having empty lists.
+                    assert isinstance(sampled_token_ids, list), \
+                        "sampled_token_ids should be a python list when " \
+                        "padded-batch is disabled."
+                    assert self.drafter is not None
+                    next_token_ids = self.drafter.prepare_next_token_ids_cpu(
+                        sampled_token_ids, self.requests, self.input_batch,
+                        scheduler_output.num_scheduled_tokens)
+                else:
+                    # When using padded-batch, the sampled_token_ids should be
+                    # the gpu tensor of sampled tokens for each request, of shape
+                    # (num_reqs, num_spec_tokens + 1) with rejected tokens having
+                    # value -1.
+                    assert isinstance(sampled_token_ids, torch.Tensor), \
+                        "sampled_token_ids should be a torch.Tensor when " \
+                        "padded-batch is enabled."
+ assert self.drafter is not None + next_token_ids, valid_sampled_tokens_count = \ + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests + ) + self._copy_valid_sampled_token_count( + next_token_ids, valid_sampled_tokens_count) + + req_scheduled_tokens = scheduler_output.num_scheduled_tokens + if self.pcp_size * self.dcp_size > 1: + long_seq_metadata = self.long_seq_metadata # type: ignore + input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu + query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu + query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu + num_reqs = self.input_batch.num_reqs + ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ + query_start_loc_pcp_full_cpu[:num_reqs] + num_prefill_reqs = (ori_query_lens + > self.decode_threshold).sum().item() + num_decode_reqs = num_reqs - num_prefill_reqs + else: + long_seq_metadata = None # type: ignore + num_prefill_reqs = 0 + num_decode_reqs = 0 + if spec_decode_metadata is None: + # update pcp related params + if self.pcp_size > 1: + token_indices_to_sample = \ + query_start_loc_pcp_full[1:num_reqs + 1] - 1 + target_token_ids = input_ids_pcp_full[: + num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + target_hidden_states = hidden_states + else: + token_indices_to_sample = None + # input_ids can be None for multimodal models. + target_token_ids = self.input_ids.gpu[: + num_scheduled_tokens] + target_positions = positions[:num_scheduled_tokens] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat([ + h[:num_scheduled_tokens] + for h in aux_hidden_states + ], + dim=-1) + else: + target_hidden_states = hidden_states[: + num_scheduled_tokens] + else: + if self.pcp_size > 1: + assert common_attn_metadata is not None + common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \ + query_start_loc_pcp_full_cpu[:num_reqs + 1] + assert common_attn_metadata is not None + common_attn_metadata.query_start_loc[:num_reqs + 1] = \ + query_start_loc_pcp_full[:num_reqs + 1] + if self.vllm_config.speculative_config.disable_padded_drafter_batch: + # NOTE: Currently, MTP-fullgraph is incompatibility with pcp + token_indices_to_sample = None + assert self.drafter is not None + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens) + else: + assert self.drafter is not None + common_attn_metadata, token_indices, \ + token_indices_to_sample =\ + self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) + if self.pcp_size > 1: + target_token_ids = input_ids_pcp_full[token_indices] + target_positions = positions + target_hidden_states = hidden_states + else: + target_token_ids = self.input_ids.gpu[token_indices] + target_positions = positions[token_indices] + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[token_indices] for h in aux_hidden_states], + dim=-1) + else: + target_hidden_states = hidden_states[token_indices] + assert self.drafter is not None + draft_token_ids = self.drafter._propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, + 
common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + req_scheduled_tokens=req_scheduled_tokens, + long_seq_metadata=long_seq_metadata, + num_prefill_reqs=num_prefill_reqs, + num_decode_reqs=num_decode_reqs, + scheduler_output=scheduler_output, + num_scheduled_tokens=num_scheduled_tokens, + ) + + else: + raise ValueError("Unknown speculative decoding method: " + f"{self.speculative_config.method}") + return draft_token_ids @staticmethod