[Refactor][EAGLE] 2/N: load model and generate token (#5437)
### What this PR does / why we need it?
1. Refactor the eagle and mtp functions `load_model` and `generate_token_ids`
2. Remove redundant code in the mtp and eagle files
3. Refactor the unit tests (UT) for these files
This is part 2/N of the effort to refactor and merge mtp and eagle.
Related RFC: https://github.com/vllm-project/vllm-ascend/issues/5467
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Unit tests (UT) and other tests.
- vLLM version: release/v0.13.0
- vLLM main:
81786c8774
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
@@ -9,7 +9,6 @@ patch
|
|||||||
ModelRunner_prepare_inputs
|
ModelRunner_prepare_inputs
|
||||||
disaggregated_prefill
|
disaggregated_prefill
|
||||||
eplb_swift_balancer.md
|
eplb_swift_balancer.md
|
||||||
Multi_Token_Prediction
|
|
||||||
ACL_Graph
|
ACL_Graph
|
||||||
KV_Cache_Pool_Guide
|
KV_Cache_Pool_Guide
|
||||||
add_custom_aclnn_op
|
add_custom_aclnn_op
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ structured_output
|
|||||||
lora
|
lora
|
||||||
eplb_swift_balancer
|
eplb_swift_balancer
|
||||||
netloader
|
netloader
|
||||||
|
Multi_Token_Prediction
|
||||||
dynamic_batch
|
dynamic_batch
|
||||||
kv_pool
|
kv_pool
|
||||||
external_dp
|
external_dp
|
||||||
|
|||||||
@@ -144,9 +144,17 @@ class TestEagleProposerLoadModel(TestBase):
|
|||||||
def test_load_model_pp1(self, mock_pp_group, mock_get_model,
|
def test_load_model_pp1(self, mock_pp_group, mock_get_model,
|
||||||
mock_get_layers):
|
mock_get_layers):
|
||||||
mock_pp_group.return_value.world_size = 1
|
mock_pp_group.return_value.world_size = 1
|
||||||
mock_target_layers = {"layer1": MagicMock(), "layer2": MagicMock()}
|
mock_target_layer1 = MagicMock()
|
||||||
mock_draft_layers = {"layer1": MagicMock(), "layer3": MagicMock()}
|
mock_target_layer2 = MagicMock()
|
||||||
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
|
mock_draft_layer1 = MagicMock()
|
||||||
|
mock_draft_layer3 = MagicMock()
|
||||||
|
mock_get_layers.side_effect = [{
|
||||||
|
"layer1": mock_target_layer1,
|
||||||
|
"layer2": mock_target_layer2
|
||||||
|
}, {}, {}, {
|
||||||
|
"layer1": mock_draft_layer1,
|
||||||
|
"layer3": mock_draft_layer3
|
||||||
|
}]
|
||||||
|
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
mock_model.model.embed_tokens = MagicMock()
|
mock_model.model.embed_tokens = MagicMock()
|
||||||
@@ -158,7 +166,7 @@ class TestEagleProposerLoadModel(TestBase):
|
|||||||
|
|
||||||
self.proposer.load_model(mock_model)
|
self.proposer.load_model(mock_model)
|
||||||
mock_get_model.assert_called_once()
|
mock_get_model.assert_called_once()
|
||||||
self.assertEqual(self.proposer.attn_layer_name, "layer3")
|
self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
|
||||||
self.assertIs(self.proposer.model.model.embed_tokens,
|
self.assertIs(self.proposer.model.model.embed_tokens,
|
||||||
mock_model.model.embed_tokens)
|
mock_model.model.embed_tokens)
|
||||||
|
|
||||||
@@ -169,9 +177,14 @@ class TestEagleProposerLoadModel(TestBase):
|
|||||||
def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
|
def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model,
|
||||||
mock_get_layers):
|
mock_get_layers):
|
||||||
mock_pp_group.return_value.world_size = 2
|
mock_pp_group.return_value.world_size = 2
|
||||||
mock_target_layers = {"layer1": MagicMock()}
|
mock_target_layer1 = MagicMock()
|
||||||
mock_draft_layers = {"layer2": MagicMock()}
|
mock_draft_layer2 = MagicMock()
|
||||||
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
|
|
||||||
|
mock_get_layers.side_effect = [{
|
||||||
|
"layer1": mock_target_layer1
|
||||||
|
}, {}, {}, {
|
||||||
|
"layer2": mock_draft_layer2
|
||||||
|
}]
|
||||||
|
|
||||||
mock_model = MagicMock()
|
mock_model = MagicMock()
|
||||||
original_embed = MagicMock()
|
original_embed = MagicMock()
|
||||||
@@ -184,7 +197,7 @@ class TestEagleProposerLoadModel(TestBase):
|
|||||||
|
|
||||||
self.assertIsNot(self.proposer.model.model.embed_tokens,
|
self.assertIsNot(self.proposer.model.model.embed_tokens,
|
||||||
mock_model.model.embed_tokens)
|
mock_model.model.embed_tokens)
|
||||||
self.assertEqual(self.proposer.attn_layer_name, "layer2")
|
self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
|
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
|
||||||
@@ -200,9 +213,14 @@ class TestEagleProposerLoadModel(TestBase):
|
|||||||
mock_get_model.return_value = MagicMock(model=MagicMock(
|
mock_get_model.return_value = MagicMock(model=MagicMock(
|
||||||
embed_tokens=original_embed))
|
embed_tokens=original_embed))
|
||||||
|
|
||||||
mock_target_layers = {"layer1": MagicMock()}
|
mock_target_layer1 = MagicMock()
|
||||||
mock_draft_layers = {"layer2": MagicMock()}
|
mock_draft_layer2 = MagicMock()
|
||||||
mock_get_layers.side_effect = [mock_target_layers, mock_draft_layers]
|
|
||||||
|
mock_get_layers.side_effect = [{
|
||||||
|
"layer1": mock_target_layer1
|
||||||
|
}, {}, {}, {
|
||||||
|
"layer2": mock_draft_layer2
|
||||||
|
}]
|
||||||
mock_pp_group.return_value.world_size = 2
|
mock_pp_group.return_value.world_size = 2
|
||||||
|
|
||||||
self.proposer.model = MagicMock()
|
self.proposer.model = MagicMock()
|
||||||
@@ -307,83 +325,6 @@ class TestEagleProposerDummyRun(TestBase):
|
|||||||
self.proposer.use_cuda_graph = last_use_cuda_graph
|
self.proposer.use_cuda_graph = last_use_cuda_graph
|
||||||
|
|
||||||
|
|
||||||
class TestEagleProposerGenerateTokenIds(TestBase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.vllm_config = MagicMock(spec=VllmConfig)
|
|
||||||
self.vllm_config.speculative_config = MagicMock()
|
|
||||||
self.vllm_config.speculative_config.method = "eagle"
|
|
||||||
self.device = torch.device("cpu")
|
|
||||||
self.runner = MagicMock()
|
|
||||||
self.runner.input_batch = MagicMock()
|
|
||||||
self.runner.input_batch.req_ids = [0, 1, 2]
|
|
||||||
self.runner.requests = {
|
|
||||||
0: MagicMock(get_token_id=lambda x: 100),
|
|
||||||
1: MagicMock(get_token_id=lambda x: 101),
|
|
||||||
2: MagicMock(get_token_id=lambda x: 102),
|
|
||||||
}
|
|
||||||
self.runner.pcp_size = 1
|
|
||||||
|
|
||||||
self.vllm_config.cache_config.block_size = 16
|
|
||||||
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
|
|
||||||
self.vllm_config.scheduler_config.max_num_seqs = 32
|
|
||||||
self.vllm_config.model_config.dtype = torch.float16
|
|
||||||
self.vllm_config.model_config.max_model_len = 2048
|
|
||||||
self.vllm_config.model_config.uses_mrope = False
|
|
||||||
self.vllm_config.speculative_config.num_speculative_tokens = 2
|
|
||||||
self.vllm_config.speculative_config.speculative_token_tree = str([
|
|
||||||
(i + 1) * (0, ) for i in range(2)
|
|
||||||
])
|
|
||||||
self.vllm_config.additional_config = None
|
|
||||||
init_ascend_config(self.vllm_config)
|
|
||||||
|
|
||||||
self.mock_cpugpubuffer = patch(
|
|
||||||
"vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
|
||||||
self.mock_cpugpubuffer.start()
|
|
||||||
self.mock_supports_multimodal_inputs = patch(
|
|
||||||
"vllm.multimodal.registry.MultiModalRegistry.supports_multimodal_inputs"
|
|
||||||
)
|
|
||||||
self.mock_supports_multimodal_inputs.start()
|
|
||||||
self.proposer = EagleProposer(vllm_config=self.vllm_config,
|
|
||||||
device=self.device,
|
|
||||||
runner=self.runner)
|
|
||||||
self.proposer.attn_layer_name = "layer_0"
|
|
||||||
self.proposer._propose = MagicMock(
|
|
||||||
return_value=torch.tensor([[1, 2], [3, 4], [5, 6]]))
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.mock_cpugpubuffer.stop()
|
|
||||||
self.mock_supports_multimodal_inputs.stop()
|
|
||||||
|
|
||||||
# TODO: This is equivalent to disable_padded_drafter_batch=True.
|
|
||||||
# We need to add some cases about disable_padded_drafter_batch=False in future.
|
|
||||||
def test_generate_token_ids(self):
|
|
||||||
valid_sampled = [[20, 30, 40]]
|
|
||||||
scheduler_output = MagicMock()
|
|
||||||
scheduler_output.num_scheduled_tokens = [2, 1, 3]
|
|
||||||
positions = torch.tensor([0, 1, 2, 3, 4, 5])
|
|
||||||
hidden_states = torch.randn(6, 4096)
|
|
||||||
num_scheduled = 6
|
|
||||||
|
|
||||||
mock_attn_metadata = MagicMock()
|
|
||||||
mock_attn_metadata.slot_mapping = torch.tensor([0, 1, 2, 3, 4, 5])
|
|
||||||
mock_attn_metadata.query_start_loc = torch.tensor([0, 2, 3, 6])
|
|
||||||
mock_attn_metadata.block_tables = MagicMock()
|
|
||||||
self.proposer._get_eagle_atten_dict = MagicMock(
|
|
||||||
return_value={"layer_0": mock_attn_metadata})
|
|
||||||
|
|
||||||
result = self.proposer.generate_token_ids(
|
|
||||||
sampled_token_ids=valid_sampled,
|
|
||||||
scheduler_output=scheduler_output,
|
|
||||||
positions=positions,
|
|
||||||
num_scheduled_tokens=num_scheduled,
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.proposer._propose.assert_called_once()
|
|
||||||
self.assertEqual(result.numpy().tolist(), [[1, 2], [3, 4], [5, 6]])
|
|
||||||
|
|
||||||
|
|
||||||
class TestEagleProposerHelperMethods(TestBase):
|
class TestEagleProposerHelperMethods(TestBase):
|
||||||
|
|
||||||
# TODO: Can add some tests about prepare_next_token_ids in future.
|
# TODO: Can add some tests about prepare_next_token_ids in future.
|
||||||
|
|||||||
@@ -6,12 +6,8 @@ import torch
|
|||||||
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
|
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
|
||||||
ModelConfig, SchedulerConfig, SpeculativeConfig,
|
ModelConfig, SchedulerConfig, SpeculativeConfig,
|
||||||
VllmConfig)
|
VllmConfig)
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
|
||||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.utils import CpuGpuBuffer
|
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import init_ascend_config
|
from vllm_ascend.ascend_config import init_ascend_config
|
||||||
@@ -107,53 +103,6 @@ class TestMtpProposer:
|
|||||||
|
|
||||||
assert proposer.use_aclgraph is True
|
assert proposer.use_aclgraph is True
|
||||||
|
|
||||||
@patch("vllm.config.get_layers_from_vllm_config")
|
|
||||||
@patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader")
|
|
||||||
@patch(
|
|
||||||
"vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading")
|
|
||||||
@patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype")
|
|
||||||
@patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config")
|
|
||||||
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
|
||||||
def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config,
|
|
||||||
mock_set_dtype, mock_process_weights, mock_get_loader,
|
|
||||||
mock_get_layers, vllm_config, runner):
|
|
||||||
mock_buffer_instance = MagicMock()
|
|
||||||
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
|
|
||||||
attn_layers_all = {
|
|
||||||
"target_attn_layer": "val0",
|
|
||||||
"draft_attn_layer": "val1",
|
|
||||||
"draft_attn_exclude_by_indexer": "val2",
|
|
||||||
}
|
|
||||||
|
|
||||||
indexer_layers_all = {
|
|
||||||
"target_indexer_0": "val3",
|
|
||||||
"draft_attn_exclude_by_indexer": "val4"
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_layers_side_effect(vllm_config, cache_cls):
|
|
||||||
if cache_cls == AttentionLayerBase:
|
|
||||||
return attn_layers_all
|
|
||||||
elif cache_cls == DeepseekV32IndexerCache:
|
|
||||||
return indexer_layers_all
|
|
||||||
else:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Setup
|
|
||||||
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
|
|
||||||
proposer._init_mtp_model = MagicMock()
|
|
||||||
mock_model = MagicMock()
|
|
||||||
proposer.model = mock_model
|
|
||||||
|
|
||||||
mock_loader = MagicMock()
|
|
||||||
mock_get_loader.return_value = mock_loader
|
|
||||||
mock_loader.get_all_weights.return_value = {
|
|
||||||
"dummy_weight": torch.tensor([1.0])
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_get_layers.side_effect = get_layers_side_effect
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
proposer.load_model(mock_model)
|
|
||||||
|
|
||||||
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
|
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
|
||||||
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
|
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
|
||||||
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
||||||
@@ -209,78 +158,6 @@ class TestMtpProposer:
|
|||||||
# Check that model was called correct number of times
|
# Check that model was called correct number of times
|
||||||
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
|
assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens
|
||||||
|
|
||||||
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
|
||||||
def test_generate_token_ids(self, mock_cpu_gpu_buffer):
|
|
||||||
mock_buffer_instance = MagicMock()
|
|
||||||
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
|
|
||||||
|
|
||||||
mock_deps = MagicMock()
|
|
||||||
mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput)
|
|
||||||
mock_deps.scheduler_output.num_scheduled_tokens = 16
|
|
||||||
mock_deps.spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
|
|
||||||
mock_deps.spec_decode_metadata.num_draft_tokens = 2
|
|
||||||
mock_deps.runner = MagicMock()
|
|
||||||
mock_deps.runner.input_batch = MagicMock(num_reqs=4)
|
|
||||||
mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
|
|
||||||
mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
|
|
||||||
mock_deps.runner.pcp_size = 2
|
|
||||||
mock_deps.runner.dcp_size = 1
|
|
||||||
mock_deps.runner.pcp_manager = MagicMock()
|
|
||||||
mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer(
|
|
||||||
32,
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=False,
|
|
||||||
device='cpu',
|
|
||||||
)
|
|
||||||
mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \
|
|
||||||
torch.arange(32, dtype=torch.int32)
|
|
||||||
mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer(
|
|
||||||
5,
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=False,
|
|
||||||
device='cpu',
|
|
||||||
)
|
|
||||||
mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \
|
|
||||||
torch.tensor([0, 8, 16, 24, 32])
|
|
||||||
mock_deps.positions = torch.arange(16, dtype=torch.int32)
|
|
||||||
mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
|
|
||||||
mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
|
|
||||||
[200, -1, -1],
|
|
||||||
[300, 301, 302]])
|
|
||||||
|
|
||||||
proposer = MagicMock(spec=MtpProposer)
|
|
||||||
proposer.enable_shared_expert_dp = False
|
|
||||||
proposer.runner = mock_deps.runner
|
|
||||||
proposer.decode_threshold = 1
|
|
||||||
proposer.speculative_config = MagicMock(
|
|
||||||
disable_padded_drafter_batch=False)
|
|
||||||
proposer.pcp_size = mock_deps.runner.pcp_size
|
|
||||||
proposer.dcp_size = mock_deps.runner.dcp_size
|
|
||||||
proposer.prepare_next_token_ids_padded = MagicMock(
|
|
||||||
return_value=(torch.tensor([101, 200, 302]), 3))
|
|
||||||
proposer.prepare_inputs_padded = MagicMock(
|
|
||||||
return_value=(MagicMock(), torch.tensor([0, 2, 4]),
|
|
||||||
torch.tensor([7, 15, 23])))
|
|
||||||
proposer._propose = MagicMock(
|
|
||||||
return_value=torch.tensor([400, 401, 402]))
|
|
||||||
proposer.generate_token_ids = MtpProposer.generate_token_ids.__get__(
|
|
||||||
proposer)
|
|
||||||
|
|
||||||
draft_token_ids = proposer.generate_token_ids(
|
|
||||||
sampled_token_ids=mock_deps.sampled_token_ids,
|
|
||||||
scheduler_output=mock_deps.scheduler_output,
|
|
||||||
spec_decode_metadata=mock_deps.spec_decode_metadata,
|
|
||||||
positions=mock_deps.positions,
|
|
||||||
num_scheduled_tokens=mock_deps.scheduler_output.
|
|
||||||
num_scheduled_tokens,
|
|
||||||
hidden_states=mock_deps.hidden_states,
|
|
||||||
)
|
|
||||||
|
|
||||||
proposer.prepare_next_token_ids_padded.assert_called_once()
|
|
||||||
proposer.prepare_inputs_padded.assert_called_once()
|
|
||||||
proposer._propose.assert_called_once()
|
|
||||||
assert torch.equal(draft_token_ids, proposer._propose.return_value)
|
|
||||||
|
|
||||||
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
||||||
def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
|
def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
|
||||||
mock_buffer_instance = MagicMock()
|
mock_buffer_instance = MagicMock()
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from typing import Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.attention.layer import Attention
|
|
||||||
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
||||||
get_layers_from_vllm_config)
|
get_layers_from_vllm_config)
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
@@ -13,6 +12,7 @@ from vllm.logger import logger
|
|||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.models import supports_multimodal
|
from vllm.model_executor.models import supports_multimodal
|
||||||
|
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
@@ -109,25 +109,54 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
|
|
||||||
def load_model(self, model: nn.Module) -> None:
|
def load_model(self, model: nn.Module) -> None:
|
||||||
target_attn_layer_names = set(
|
target_attn_layer_names = set(
|
||||||
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
|
get_layers_from_vllm_config(self.vllm_config,
|
||||||
|
AttentionLayerBase).keys())
|
||||||
|
target_indexer_layer_names = set(
|
||||||
|
get_layers_from_vllm_config(self.vllm_config,
|
||||||
|
DeepseekV32IndexerCache).keys())
|
||||||
|
|
||||||
self.model = get_model(vllm_config=self.vllm_config,
|
self.model = get_model(vllm_config=self.vllm_config,
|
||||||
model_config=self.vllm_config.
|
model_config=self.vllm_config.
|
||||||
speculative_config.draft_model_config)
|
speculative_config.draft_model_config)
|
||||||
draft_attn_layer_names = (get_layers_from_vllm_config(
|
|
||||||
self.vllm_config, AttentionLayerBase).keys() -
|
indexer_layers = get_layers_from_vllm_config(
|
||||||
target_attn_layer_names)
|
self.vllm_config, DeepseekV32IndexerCache).keys()
|
||||||
self.attn_layer_name = next(iter(draft_attn_layer_names))
|
draft_attn_layer = get_layers_from_vllm_config(
|
||||||
|
self.vllm_config, AttentionLayerBase).keys()
|
||||||
|
|
||||||
|
draft_attn_layer_names = draft_attn_layer - target_attn_layer_names
|
||||||
|
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
|
||||||
|
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
|
||||||
|
assert len(draft_attn_layer_names) == 1
|
||||||
|
self.attn_layer_name = list(draft_attn_layer_names)
|
||||||
|
|
||||||
# share embed_tokens with the target model if needed
|
# share embed_tokens with the target model if needed
|
||||||
if get_pp_group().world_size == 1:
|
if get_pp_group().world_size == 1:
|
||||||
logger.info(
|
if self.method == "mtp":
|
||||||
"The EAGLE head shares the same vocab embedding" \
|
if self.vllm_config.model_config.is_deepseek_mla and \
|
||||||
" with the target model."
|
torch.equal(self.model.model.embed_tokens.weight,
|
||||||
)
|
model.model.embed_tokens.weight):
|
||||||
self.model.model.embed_tokens = model.model.embed_tokens
|
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
|
||||||
|
# check if mtp model use main model's embedding and LMhead
|
||||||
|
logger.info(
|
||||||
|
"The MTP head shares the same vocab embedding" \
|
||||||
|
" with the target model."
|
||||||
|
)
|
||||||
|
self.model.model.embed_tokens = model.model.embed_tokens
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
" The MTP head loaded its own vocab embedding" \
|
||||||
|
" weights instead of sharing them with the target model."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"The EAGLE head shares the same vocab embedding" \
|
||||||
|
" with the target model."
|
||||||
|
)
|
||||||
|
self.model.model.embed_tokens = model.model.embed_tokens
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \
|
"Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
|
||||||
" weights instead of sharing them with the target model."
|
" weights instead of sharing them with the target model."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -141,6 +170,13 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
else:
|
else:
|
||||||
self.model.lm_head = model.lm_head
|
self.model.lm_head = model.lm_head
|
||||||
|
|
||||||
|
if self.method == "mtp" and \
|
||||||
|
self.vllm_config.model_config.is_deepseek_mla:
|
||||||
|
for _, layer_module in self.model.model.layers.items():
|
||||||
|
if torch.equal(layer_module.shared_head.head.weight,
|
||||||
|
model.lm_head.weight):
|
||||||
|
layer_module.shared_head.head = model.lm_head
|
||||||
|
|
||||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||||
) and self.use_cuda_graph:
|
) and self.use_cuda_graph:
|
||||||
self.update_stream = torch.npu.Stream()
|
self.update_stream = torch.npu.Stream()
|
||||||
@@ -205,7 +241,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
attn_metadata_eagle = builder.build_for_graph_capture(
|
attn_metadata_eagle = builder.build_for_graph_capture(
|
||||||
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
|
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
|
||||||
attn_metadata = {}
|
attn_metadata = {}
|
||||||
for layer_name in [self.attn_layer_name]:
|
for layer_name in self.attn_layer_name:
|
||||||
attn_metadata[layer_name] = attn_metadata_eagle
|
attn_metadata[layer_name] = attn_metadata_eagle
|
||||||
for i in range(self.num_speculative_tokens):
|
for i in range(self.num_speculative_tokens):
|
||||||
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
@@ -235,135 +271,6 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_token_ids(self,
|
|
||||||
sampled_token_ids: torch.Tensor | list[list[int]],
|
|
||||||
sampling_metadata: SamplingMetadata = None,
|
|
||||||
scheduler_output: SchedulerOutput = None,
|
|
||||||
spec_decode_metadata: SpecDecodeMetadata = None,
|
|
||||||
positions: torch.Tensor = None,
|
|
||||||
num_scheduled_tokens: int = 0,
|
|
||||||
hidden_states: torch.Tensor = None,
|
|
||||||
aux_hidden_states: torch.Tensor = None):
|
|
||||||
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
|
||||||
|
|
||||||
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
|
|
||||||
# When padded-batch is disabled, the sampled_token_ids should be
|
|
||||||
# the cpu-side list[list[int]] of valid sampled tokens for each
|
|
||||||
# request, with invalid requests having empty lists.
|
|
||||||
assert isinstance(sampled_token_ids, list), \
|
|
||||||
"sampled_token_ids should be a python list when" \
|
|
||||||
"padded-batch is disabled."
|
|
||||||
next_token_ids = self.prepare_next_token_ids_cpu(
|
|
||||||
sampled_token_ids, self.runner.requests,
|
|
||||||
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
|
|
||||||
else:
|
|
||||||
# When using padded-batch, the sampled_token_ids should be
|
|
||||||
# the gpu tensor of sampled tokens for each request, of shape
|
|
||||||
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
|
|
||||||
# value -1.
|
|
||||||
assert isinstance(sampled_token_ids, torch.Tensor), \
|
|
||||||
"sampled_token_ids should be a torch.Tensor when" \
|
|
||||||
"padded-batch is enabled."
|
|
||||||
next_token_ids, valid_sampled_tokens_count = \
|
|
||||||
self.prepare_next_token_ids_padded(
|
|
||||||
common_attn_metadata,
|
|
||||||
sampled_token_ids,
|
|
||||||
self.runner.requests,
|
|
||||||
self.runner.input_batch,
|
|
||||||
self.runner.discard_request_indices.gpu,
|
|
||||||
self.runner.num_discarded_requests
|
|
||||||
)
|
|
||||||
self._copy_valid_sampled_token_count(next_token_ids,
|
|
||||||
valid_sampled_tokens_count)
|
|
||||||
|
|
||||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
long_seq_metadata = self.runner.long_seq_metadata
|
|
||||||
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
|
|
||||||
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
|
|
||||||
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
|
|
||||||
num_reqs = self.runner.input_batch.num_reqs
|
|
||||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
|
||||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
|
||||||
num_prefill_reqs = (ori_query_lens
|
|
||||||
> self.decode_threshold).sum().item()
|
|
||||||
num_decode_reqs = num_reqs - num_prefill_reqs
|
|
||||||
else:
|
|
||||||
long_seq_metadata = None
|
|
||||||
num_prefill_reqs = 0
|
|
||||||
num_decode_reqs = 0
|
|
||||||
if spec_decode_metadata is None:
|
|
||||||
# update pcp related params
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
token_indices_to_sample = \
|
|
||||||
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
|
|
||||||
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
|
|
||||||
target_positions = positions[:num_scheduled_tokens]
|
|
||||||
target_hidden_states = hidden_states
|
|
||||||
else:
|
|
||||||
token_indices_to_sample = None
|
|
||||||
# input_ids can be None for multimodal models.
|
|
||||||
target_token_ids = self.runner.input_ids.gpu[:
|
|
||||||
num_scheduled_tokens]
|
|
||||||
target_positions = positions[:num_scheduled_tokens]
|
|
||||||
if self.method == "eagle3":
|
|
||||||
target_hidden_states = torch.cat(
|
|
||||||
[h[:num_scheduled_tokens] for h in aux_hidden_states],
|
|
||||||
dim=-1)
|
|
||||||
else:
|
|
||||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
|
||||||
else:
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
common_attn_metadata.query_start_loc_cpu = \
|
|
||||||
query_start_loc_pcp_full_cpu[:num_reqs + 1]
|
|
||||||
common_attn_metadata.query_start_loc = \
|
|
||||||
query_start_loc_pcp_full[:num_reqs + 1]
|
|
||||||
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
|
|
||||||
# NOTE: Currently, MTP-fullgraph is incompatibility with pcp
|
|
||||||
token_indices_to_sample = None
|
|
||||||
common_attn_metadata, token_indices =\
|
|
||||||
self.prepare_inputs(
|
|
||||||
common_attn_metadata,
|
|
||||||
sampled_token_ids,
|
|
||||||
spec_decode_metadata.num_draft_tokens)
|
|
||||||
else:
|
|
||||||
common_attn_metadata, token_indices, \
|
|
||||||
token_indices_to_sample =\
|
|
||||||
self.prepare_inputs_padded(
|
|
||||||
common_attn_metadata,
|
|
||||||
spec_decode_metadata,
|
|
||||||
valid_sampled_tokens_count)
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
target_token_ids = input_ids_pcp_full[token_indices]
|
|
||||||
target_positions = positions
|
|
||||||
target_hidden_states = hidden_states
|
|
||||||
else:
|
|
||||||
target_token_ids = self.runner.input_ids.gpu[token_indices]
|
|
||||||
target_positions = positions[token_indices]
|
|
||||||
if self.method == "eagle3":
|
|
||||||
target_hidden_states = torch.cat(
|
|
||||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
|
||||||
else:
|
|
||||||
target_hidden_states = hidden_states[token_indices]
|
|
||||||
|
|
||||||
draft_token_ids = self._propose(
|
|
||||||
target_token_ids=target_token_ids,
|
|
||||||
target_positions=target_positions,
|
|
||||||
target_hidden_states=target_hidden_states,
|
|
||||||
next_token_ids=next_token_ids,
|
|
||||||
last_token_indices=token_indices_to_sample,
|
|
||||||
common_attn_metadata=common_attn_metadata,
|
|
||||||
sampling_metadata=sampling_metadata,
|
|
||||||
req_scheduled_tokens=req_scheduled_tokens,
|
|
||||||
long_seq_metadata=long_seq_metadata,
|
|
||||||
num_prefill_reqs=num_prefill_reqs,
|
|
||||||
num_decode_reqs=num_decode_reqs,
|
|
||||||
scheduler_output=scheduler_output,
|
|
||||||
num_scheduled_tokens=num_scheduled_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
return draft_token_ids
|
|
||||||
|
|
||||||
def _propose(
|
def _propose(
|
||||||
self,
|
self,
|
||||||
# [num_tokens]
|
# [num_tokens]
|
||||||
@@ -430,9 +337,11 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
self.runner.get_model())
|
self.runner.get_model())
|
||||||
# update global cos, sin
|
# update global cos, sin
|
||||||
update_cos_sin(self.positions[:num_input_tokens])
|
update_cos_sin(self.positions[:num_input_tokens])
|
||||||
|
per_layer_attn_metadata = {}
|
||||||
|
for layer_name in self.attn_layer_name:
|
||||||
|
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
{self.attn_layer_name: attn_metadata},
|
per_layer_attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=num_input_tokens,
|
num_tokens=num_input_tokens,
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
@@ -558,7 +467,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
|
|
||||||
# Run the model.
|
# Run the model.
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
{self.attn_layer_name: attn_metadata},
|
per_layer_attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=input_batch_size,
|
num_tokens=input_batch_size,
|
||||||
num_actual_tokens=batch_size,
|
num_actual_tokens=batch_size,
|
||||||
@@ -696,28 +605,6 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
|
|
||||||
return next_token_ids, valid_sampled_tokens_count
|
return next_token_ids, valid_sampled_tokens_count
|
||||||
|
|
||||||
def _copy_valid_sampled_token_count(
|
|
||||||
self, next_token_ids: torch.Tensor,
|
|
||||||
valid_sampled_tokens_count: torch.Tensor) -> None:
|
|
||||||
if self.runner.valid_sampled_token_count_event is not None:
|
|
||||||
default_stream = torch.npu.current_stream()
|
|
||||||
# initialize a new stream to overlap the copy operation with
|
|
||||||
# prepare_input of draft model.
|
|
||||||
with torch.npu.stream(
|
|
||||||
self.runner.valid_sampled_token_count_copy_stream):
|
|
||||||
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
|
|
||||||
default_stream) # type: ignore
|
|
||||||
self.runner.valid_sampled_token_count_cpu[:
|
|
||||||
valid_sampled_tokens_count
|
|
||||||
.shape[0]].copy_(
|
|
||||||
valid_sampled_tokens_count,
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.runner.valid_sampled_token_count_event.record()
|
|
||||||
|
|
||||||
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
|
|
||||||
1)
|
|
||||||
|
|
||||||
def prepare_inputs(
|
def prepare_inputs(
|
||||||
self,
|
self,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
|||||||
@@ -1,25 +1,16 @@
|
|||||||
import importlib
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config,
|
from vllm.config import CUDAGraphMode
|
||||||
set_current_vllm_config)
|
|
||||||
from vllm.distributed import get_pcp_group
|
from vllm.distributed import get_pcp_group
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
|
||||||
from vllm.model_executor.model_loader.utils import \
|
|
||||||
process_weights_after_loading
|
|
||||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
@@ -54,15 +45,6 @@ _MTP_MODELS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _load_model(architecture):
|
|
||||||
if architecture not in _MTP_MODELS:
|
|
||||||
raise ValueError("Invalid architecture for mtp.")
|
|
||||||
module_name, model_name = _MTP_MODELS[architecture]
|
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
model = getattr(module, model_name)
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
class MtpProposer(EagleProposer):
|
class MtpProposer(EagleProposer):
|
||||||
|
|
||||||
# TODO: Find out why ModelRunner does not need this explicit typing.
|
# TODO: Find out why ModelRunner does not need this explicit typing.
|
||||||
@@ -86,64 +68,6 @@ class MtpProposer(EagleProposer):
|
|||||||
update_attn_params(self.update_stream, forward_context,
|
update_attn_params(self.update_stream, forward_context,
|
||||||
num_tokens, self.vllm_config)
|
num_tokens, self.vllm_config)
|
||||||
|
|
||||||
def load_model(self, model) -> None:
|
|
||||||
loader = get_model_loader(self.vllm_config.load_config)
|
|
||||||
|
|
||||||
target_attn_layer_names = set(
|
|
||||||
get_layers_from_vllm_config(self.vllm_config,
|
|
||||||
AttentionLayerBase).keys())
|
|
||||||
target_indexer_layer_names = set(
|
|
||||||
get_layers_from_vllm_config(self.vllm_config,
|
|
||||||
DeepseekV32IndexerCache).keys())
|
|
||||||
draft_model_config = \
|
|
||||||
self.vllm_config.speculative_config.draft_model_config
|
|
||||||
target_device = self.vllm_config.device_config.device
|
|
||||||
|
|
||||||
with set_default_torch_dtype(
|
|
||||||
draft_model_config.dtype), set_current_vllm_config(
|
|
||||||
self.vllm_config):
|
|
||||||
self._init_mtp_model()
|
|
||||||
draft_attn_layer_names = (get_layers_from_vllm_config(
|
|
||||||
self.vllm_config, AttentionLayerBase).keys() -
|
|
||||||
target_attn_layer_names)
|
|
||||||
indexer_layers = get_layers_from_vllm_config(self.vllm_config,
|
|
||||||
DeepseekV32IndexerCache)
|
|
||||||
draft_indexer_layer_names = indexer_layers.keys(
|
|
||||||
) - target_indexer_layer_names
|
|
||||||
# NOTE: Currently we don't have specific attention backend and attention metadata
|
|
||||||
# for deepseek v3.2 indexer, so we just exclude the indexer layers here.
|
|
||||||
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
|
|
||||||
|
|
||||||
assert len(draft_attn_layer_names) == 1
|
|
||||||
self.attn_layer_name = list(draft_attn_layer_names)
|
|
||||||
|
|
||||||
self.model.load_weights(
|
|
||||||
loader.get_all_weights(
|
|
||||||
self.vllm_config.speculative_config.draft_model_config,
|
|
||||||
self.model))
|
|
||||||
process_weights_after_loading(self.model, draft_model_config,
|
|
||||||
target_device)
|
|
||||||
|
|
||||||
if self.vllm_config.model_config.is_deepseek_mla:
|
|
||||||
# check if the mtp model uses the main model's embedding and LM head
|
|
||||||
main_model = model
|
|
||||||
if get_pp_group().world_size == 1:
|
|
||||||
# If pp>1, the weights of mtp and the main model's embedding are not on the same device.
|
|
||||||
if torch.equal(self.model.model.embed_tokens.weight,
|
|
||||||
main_model.model.embed_tokens.weight):
|
|
||||||
self.model.model.embed_tokens = main_model.model.embed_tokens
|
|
||||||
for _, layer_module in self.model.model.layers.items():
|
|
||||||
if torch.equal(layer_module.shared_head.head.weight,
|
|
||||||
main_model.lm_head.weight):
|
|
||||||
layer_module.shared_head.head = main_model.lm_head
|
|
||||||
|
|
||||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
|
||||||
):
|
|
||||||
self.update_stream: torch.npu.Stream = torch.npu.Stream()
|
|
||||||
self.model = ACLGraphWrapper(self.model,
|
|
||||||
self.vllm_config,
|
|
||||||
runtime_mode=CUDAGraphMode.FULL)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def dummy_run(self,
|
def dummy_run(self,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
@@ -256,153 +180,6 @@ class MtpProposer(EagleProposer):
|
|||||||
if with_prefill:
|
if with_prefill:
|
||||||
break
|
break
|
||||||
|
|
||||||
def generate_token_ids(self,
|
|
||||||
sampled_token_ids: torch.Tensor | list[list[int]],
|
|
||||||
sampling_metadata: SamplingMetadata = None,
|
|
||||||
scheduler_output: SchedulerOutput = None,
|
|
||||||
spec_decode_metadata: SpecDecodeMetadata = None,
|
|
||||||
positions: torch.Tensor = None,
|
|
||||||
num_scheduled_tokens: int = 0,
|
|
||||||
hidden_states: torch.Tensor = None,
|
|
||||||
aux_hidden_states: torch.Tensor = None):
|
|
||||||
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
|
||||||
|
|
||||||
if self.speculative_config.disable_padded_drafter_batch:
|
|
||||||
# When padded-batch is disabled, the sampled_token_ids should be
|
|
||||||
# the cpu-side list[list[int]] of valid sampled tokens for each
|
|
||||||
# request, with invalid requests having empty lists.
|
|
||||||
assert isinstance(sampled_token_ids, list), \
|
|
||||||
"sampled_token_ids should be a python list when" \
|
|
||||||
"padded-batch is disabled."
|
|
||||||
next_token_ids = self.prepare_next_token_ids_cpu(
|
|
||||||
sampled_token_ids, self.runner.requests,
|
|
||||||
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
|
|
||||||
else:
|
|
||||||
# When using padded-batch, the sampled_token_ids should be
|
|
||||||
# the gpu tensor of sampled tokens for each request, of shape
|
|
||||||
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
|
|
||||||
# value -1.
|
|
||||||
assert isinstance(sampled_token_ids, torch.Tensor), \
|
|
||||||
"sampled_token_ids should be a torch.Tensor when" \
|
|
||||||
"padded-batch is enabled."
|
|
||||||
next_token_ids, valid_sampled_tokens_count = \
|
|
||||||
self.prepare_next_token_ids_padded(
|
|
||||||
common_attn_metadata,
|
|
||||||
sampled_token_ids,
|
|
||||||
self.runner.requests,
|
|
||||||
self.runner.input_batch,
|
|
||||||
self.runner.discard_request_indices.gpu,
|
|
||||||
self.runner.num_discarded_requests
|
|
||||||
)
|
|
||||||
self._copy_valid_sampled_token_count(next_token_ids,
|
|
||||||
valid_sampled_tokens_count)
|
|
||||||
|
|
||||||
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
|
||||||
long_seq_metadata = self.runner.long_seq_metadata
|
|
||||||
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
|
|
||||||
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
|
|
||||||
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
|
|
||||||
num_reqs = self.runner.input_batch.num_reqs
|
|
||||||
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
|
||||||
query_start_loc_pcp_full_cpu[:num_reqs]
|
|
||||||
num_prefill_reqs = (ori_query_lens
|
|
||||||
> self.decode_threshold).sum().item()
|
|
||||||
num_decode_reqs = num_reqs - num_prefill_reqs
|
|
||||||
else:
|
|
||||||
long_seq_metadata = None
|
|
||||||
num_prefill_reqs = 0
|
|
||||||
num_decode_reqs = 0
|
|
||||||
if spec_decode_metadata is None:
|
|
||||||
# update pcp related params
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
token_indices_to_sample = \
|
|
||||||
query_start_loc_pcp_full[1:num_reqs + 1] - 1
|
|
||||||
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
|
|
||||||
target_positions = positions[:num_scheduled_tokens]
|
|
||||||
target_hidden_states = hidden_states
|
|
||||||
else:
|
|
||||||
token_indices_to_sample = None
|
|
||||||
# input_ids can be None for multimodal models.
|
|
||||||
target_token_ids = self.runner.input_ids.gpu[:
|
|
||||||
num_scheduled_tokens]
|
|
||||||
target_positions = positions[:num_scheduled_tokens]
|
|
||||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
|
||||||
else:
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
|
|
||||||
query_start_loc_pcp_full_cpu[:num_reqs + 1]
|
|
||||||
common_attn_metadata.query_start_loc[:num_reqs + 1] = \
|
|
||||||
query_start_loc_pcp_full[:num_reqs + 1]
|
|
||||||
if self.speculative_config.disable_padded_drafter_batch:
|
|
||||||
token_indices_to_sample = None
|
|
||||||
common_attn_metadata, token_indices =\
|
|
||||||
self._prepare_inputs(
|
|
||||||
common_attn_metadata,
|
|
||||||
sampled_token_ids,
|
|
||||||
spec_decode_metadata.num_draft_tokens)
|
|
||||||
else:
|
|
||||||
common_attn_metadata, token_indices, \
|
|
||||||
token_indices_to_sample =\
|
|
||||||
self.prepare_inputs_padded(
|
|
||||||
common_attn_metadata,
|
|
||||||
spec_decode_metadata,
|
|
||||||
valid_sampled_tokens_count)
|
|
||||||
if self.pcp_size > 1:
|
|
||||||
target_token_ids = input_ids_pcp_full[token_indices]
|
|
||||||
target_positions = positions
|
|
||||||
target_hidden_states = hidden_states
|
|
||||||
else:
|
|
||||||
target_token_ids = self.runner.input_ids.gpu[token_indices]
|
|
||||||
target_positions = positions[token_indices]
|
|
||||||
target_hidden_states = hidden_states[token_indices]
|
|
||||||
|
|
||||||
draft_token_ids = self._propose(
|
|
||||||
target_token_ids=target_token_ids,
|
|
||||||
target_positions=target_positions,
|
|
||||||
target_hidden_states=target_hidden_states,
|
|
||||||
next_token_ids=next_token_ids,
|
|
||||||
last_token_indices=token_indices_to_sample,
|
|
||||||
common_attn_metadata=common_attn_metadata,
|
|
||||||
sampling_metadata=sampling_metadata,
|
|
||||||
req_scheduled_tokens=req_scheduled_tokens,
|
|
||||||
long_seq_metadata=long_seq_metadata,
|
|
||||||
num_prefill_reqs=num_prefill_reqs,
|
|
||||||
num_decode_reqs=num_decode_reqs,
|
|
||||||
scheduler_output=scheduler_output,
|
|
||||||
num_scheduled_tokens=num_scheduled_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
return draft_token_ids
|
|
||||||
|
|
||||||
def _copy_valid_sampled_token_count(
|
|
||||||
self, next_token_ids: torch.Tensor,
|
|
||||||
valid_sampled_tokens_count: torch.Tensor) -> None:
|
|
||||||
if self.runner.valid_sampled_token_count_event is not None:
|
|
||||||
default_stream = torch.npu.current_stream()
|
|
||||||
# initialize a new stream to overlap the copy operation with
|
|
||||||
# prepare_input of draft model.
|
|
||||||
with torch.npu.stream(
|
|
||||||
self.runner.valid_sampled_token_count_copy_stream):
|
|
||||||
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
|
|
||||||
default_stream) # type: ignore
|
|
||||||
self.runner.valid_sampled_token_count_cpu[:
|
|
||||||
valid_sampled_tokens_count
|
|
||||||
.shape[0]].copy_(
|
|
||||||
valid_sampled_tokens_count,
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.runner.valid_sampled_token_count_event.record()
|
|
||||||
|
|
||||||
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
|
|
||||||
1)
|
|
||||||
|
|
||||||
def _init_mtp_model(self):
|
|
||||||
architecture = self.vllm_config.model_config.architecture
|
|
||||||
target_device = self.vllm_config.device_config.device
|
|
||||||
model = _load_model(architecture)
|
|
||||||
self.model = model(vllm_config=self.vllm_config).to(target_device)
|
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#
|
#
|
||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
# Copyright 2023 The vLLM team.
|
# Copyright 2025 The vLLM team.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@@ -54,6 +54,7 @@ from vllm.utils.math_utils import cdiv
|
|||||||
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
from vllm.utils.mem_utils import DeviceMemoryProfiler
|
||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||||
EncoderOnlyAttentionSpec,
|
EncoderOnlyAttentionSpec,
|
||||||
FullAttentionSpec, KVCacheConfig,
|
FullAttentionSpec, KVCacheConfig,
|
||||||
@@ -113,7 +114,6 @@ from vllm_ascend.worker.pcp_utils import PCPManager
|
|||||||
from vllm_ascend.ascend_forward_context import ( # isort: skip
|
from vllm_ascend.ascend_forward_context import ( # isort: skip
|
||||||
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
|
MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
|
||||||
set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)
|
set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import xgrammar as xgr # type: ignore[import-untyped]
|
import xgrammar as xgr # type: ignore[import-untyped]
|
||||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||||
@@ -1257,6 +1257,7 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
logits_indices=logits_indices,
|
logits_indices=logits_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: Once the PCP features are complete, it will fully inherit the classes from the VLLM community.
|
||||||
def propose_draft_token_ids(
|
def propose_draft_token_ids(
|
||||||
self,
|
self,
|
||||||
valid_sampled_token_ids: torch.Tensor | list[list[int]],
|
valid_sampled_token_ids: torch.Tensor | list[list[int]],
|
||||||
@@ -1273,10 +1274,147 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
# Speculative decoding is not enabled.
|
# Speculative decoding is not enabled.
|
||||||
draft_token_ids = None
|
draft_token_ids = None
|
||||||
else:
|
else:
|
||||||
draft_token_ids = self.drafter.generate_token_ids(
|
if self.speculative_config.method in ("suffix", "ngram"):
|
||||||
valid_sampled_token_ids, sampling_metadata, scheduler_output,
|
draft_token_ids = self.drafter.generate_token_ids(
|
||||||
spec_decode_metadata, positions, num_scheduled_tokens,
|
valid_sampled_token_ids, sampling_metadata,
|
||||||
hidden_states, aux_hidden_states)
|
scheduler_output, spec_decode_metadata, positions,
|
||||||
|
num_scheduled_tokens, hidden_states, aux_hidden_states)
|
||||||
|
|
||||||
|
elif self.speculative_config.use_eagle():
|
||||||
|
common_attn_metadata = self.spec_decode_common_attn_metadata
|
||||||
|
sampled_token_ids = valid_sampled_token_ids
|
||||||
|
|
||||||
|
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
|
||||||
|
# When padded-batch is disabled, the sampled_token_ids should be
|
||||||
|
# the cpu-side list[list[int]] of valid sampled tokens for each
|
||||||
|
# request, with invalid requests having empty lists.
|
||||||
|
assert isinstance(sampled_token_ids, list), \
|
||||||
|
"sampled_token_ids should be a python list when" \
|
||||||
|
"padded-batch is disabled."
|
||||||
|
assert self.drafter is not None
|
||||||
|
next_token_ids = self.drafter.prepare_next_token_ids_cpu(
|
||||||
|
sampled_token_ids, self.requests, self.input_batch,
|
||||||
|
scheduler_output.num_scheduled_tokens)
|
||||||
|
else:
|
||||||
|
# When using padded-batch, the sampled_token_ids should be
|
||||||
|
# the gpu tensor of sampled tokens for each request, of shape
|
||||||
|
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
|
||||||
|
# value -1.
|
||||||
|
assert isinstance(sampled_token_ids, torch.Tensor), \
|
||||||
|
"sampled_token_ids should be a torch.Tensor when" \
|
||||||
|
"padded-batch is enabled."
|
||||||
|
assert self.drafter is not None
|
||||||
|
next_token_ids, valid_sampled_tokens_count = \
|
||||||
|
self.drafter.prepare_next_token_ids_padded(
|
||||||
|
common_attn_metadata,
|
||||||
|
sampled_token_ids,
|
||||||
|
self.requests,
|
||||||
|
self.input_batch,
|
||||||
|
self.discard_request_indices.gpu,
|
||||||
|
self.num_discarded_requests
|
||||||
|
)
|
||||||
|
self._copy_valid_sampled_token_count(
|
||||||
|
next_token_ids, valid_sampled_tokens_count)
|
||||||
|
|
||||||
|
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
long_seq_metadata = self.long_seq_metadata # type: ignore
|
||||||
|
input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
|
||||||
|
query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
|
||||||
|
query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
|
||||||
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
|
||||||
|
query_start_loc_pcp_full_cpu[:num_reqs]
|
||||||
|
num_prefill_reqs = (ori_query_lens
|
||||||
|
> self.decode_threshold).sum().item()
|
||||||
|
num_decode_reqs = num_reqs - num_prefill_reqs
|
||||||
|
else:
|
||||||
|
long_seq_metadata = None # type: ignore
|
||||||
|
num_prefill_reqs = 0
|
||||||
|
num_decode_reqs = 0
|
||||||
|
if spec_decode_metadata is None:
|
||||||
|
# update pcp related params
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
token_indices_to_sample = \
|
||||||
|
query_start_loc_pcp_full[1:num_reqs + 1] - 1
|
||||||
|
target_token_ids = input_ids_pcp_full[:
|
||||||
|
num_scheduled_tokens]
|
||||||
|
target_positions = positions[:num_scheduled_tokens]
|
||||||
|
target_hidden_states = hidden_states
|
||||||
|
else:
|
||||||
|
token_indices_to_sample = None
|
||||||
|
# input_ids can be None for multimodal models.
|
||||||
|
target_token_ids = self.input_ids.gpu[:
|
||||||
|
num_scheduled_tokens]
|
||||||
|
target_positions = positions[:num_scheduled_tokens]
|
||||||
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
target_hidden_states = torch.cat([
|
||||||
|
h[:num_scheduled_tokens]
|
||||||
|
for h in aux_hidden_states
|
||||||
|
],
|
||||||
|
dim=-1)
|
||||||
|
else:
|
||||||
|
target_hidden_states = hidden_states[:
|
||||||
|
num_scheduled_tokens]
|
||||||
|
else:
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
assert common_attn_metadata is not None
|
||||||
|
common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
|
||||||
|
query_start_loc_pcp_full_cpu[:num_reqs + 1]
|
||||||
|
assert common_attn_metadata is not None
|
||||||
|
common_attn_metadata.query_start_loc[:num_reqs + 1] = \
|
||||||
|
query_start_loc_pcp_full[:num_reqs + 1]
|
||||||
|
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
|
||||||
|
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
|
||||||
|
token_indices_to_sample = None
|
||||||
|
assert self.drafter is not None
|
||||||
|
common_attn_metadata, token_indices =\
|
||||||
|
self.drafter.prepare_inputs(
|
||||||
|
common_attn_metadata,
|
||||||
|
sampled_token_ids,
|
||||||
|
spec_decode_metadata.num_draft_tokens)
|
||||||
|
else:
|
||||||
|
assert self.drafter is not None
|
||||||
|
common_attn_metadata, token_indices, \
|
||||||
|
token_indices_to_sample =\
|
||||||
|
self.drafter.prepare_inputs_padded(
|
||||||
|
common_attn_metadata,
|
||||||
|
spec_decode_metadata,
|
||||||
|
valid_sampled_tokens_count)
|
||||||
|
if self.pcp_size > 1:
|
||||||
|
target_token_ids = input_ids_pcp_full[token_indices]
|
||||||
|
target_positions = positions
|
||||||
|
target_hidden_states = hidden_states
|
||||||
|
else:
|
||||||
|
target_token_ids = self.input_ids.gpu[token_indices]
|
||||||
|
target_positions = positions[token_indices]
|
||||||
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
target_hidden_states = torch.cat(
|
||||||
|
[h[token_indices] for h in aux_hidden_states],
|
||||||
|
dim=-1)
|
||||||
|
else:
|
||||||
|
target_hidden_states = hidden_states[token_indices]
|
||||||
|
assert self.drafter is not None
|
||||||
|
draft_token_ids = self.drafter._propose(
|
||||||
|
target_token_ids=target_token_ids,
|
||||||
|
target_positions=target_positions,
|
||||||
|
target_hidden_states=target_hidden_states,
|
||||||
|
next_token_ids=next_token_ids,
|
||||||
|
last_token_indices=token_indices_to_sample,
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
sampling_metadata=sampling_metadata,
|
||||||
|
req_scheduled_tokens=req_scheduled_tokens,
|
||||||
|
long_seq_metadata=long_seq_metadata,
|
||||||
|
num_prefill_reqs=num_prefill_reqs,
|
||||||
|
num_decode_reqs=num_decode_reqs,
|
||||||
|
scheduler_output=scheduler_output,
|
||||||
|
num_scheduled_tokens=num_scheduled_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown speculative decoding method: "
|
||||||
|
f"{self.speculative_config.method}")
|
||||||
|
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user