[Refactor][EAGLE] 2/N: load model and generate token (#5437)
### What this PR does / why we need it?
1. Refactor the eagle and MTP functions `load_model` and `generate_token_ids`
2. Remove redundant code in the MTP and eagle files
3. Refactor the corresponding unit tests

This is 2/N of the refactor that merges MTP and eagle (overall direction sketched below).
Related RFC: https://github.com/vllm-project/vllm-ascend/issues/5467
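For readers following the series, here is a minimal sketch of one way such a merge can be structured: hoist the shared `load_model` / `generate_token_ids` flow into a common base class and leave only model-specific pieces to subclasses. All class and helper names below are illustrative assumptions, not the actual vllm-ascend implementation.

```python
# Hypothetical sketch only: names are illustrative assumptions, not the
# real vllm-ascend classes.
from abc import ABC, abstractmethod


class ProposerBase(ABC):
    """Shared skeleton for speculative-decode proposers."""

    def load_model(self, target_model) -> None:
        # Common flow: build the draft model from the target model and
        # keep a handle to it; subclasses supply the specifics.
        self.model = self._build_draft_model(target_model)

    def generate_token_ids(self, sampled_token_ids, **kwargs):
        # Common flow: pick the accepted next-token ids, then run the
        # draft model to propose speculative tokens.
        next_ids = self._prepare_next_token_ids(sampled_token_ids)
        return self._propose(next_ids, **kwargs)

    @abstractmethod
    def _build_draft_model(self, target_model): ...

    @abstractmethod
    def _prepare_next_token_ids(self, sampled_token_ids): ...

    @abstractmethod
    def _propose(self, next_ids, **kwargs): ...
```

Under this shape, the eagle and MTP proposers would each subclass the base and override only the abstract hooks, which is the kind of deduplication the series aims for.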
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Unit tests and existing tests.
- vLLM version: release/v0.13.0
- vLLM main: 81786c8774
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
```diff
@@ -6,12 +6,8 @@ import torch
from vllm.config import (CacheConfig, CompilationConfig, CUDAGraphMode,
                         ModelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

from vllm_ascend.ascend_config import init_ascend_config
```
```diff
@@ -107,53 +103,6 @@ class TestMtpProposer:

        assert proposer.use_aclgraph is True

    @patch("vllm.config.get_layers_from_vllm_config")
    @patch("vllm_ascend.spec_decode.mtp_proposer.get_model_loader")
    @patch(
        "vllm_ascend.spec_decode.mtp_proposer.process_weights_after_loading")
    @patch("vllm_ascend.spec_decode.mtp_proposer.set_default_torch_dtype")
    @patch("vllm_ascend.spec_decode.mtp_proposer.set_current_vllm_config")
    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_load_model(self, mock_cpu_gpu_buffer, mock_set_config,
                        mock_set_dtype, mock_process_weights, mock_get_loader,
                        mock_get_layers, vllm_config, runner):
        mock_buffer_instance = MagicMock()
        mock_cpu_gpu_buffer.return_value = mock_buffer_instance
        attn_layers_all = {
            "target_attn_layer": "val0",
            "draft_attn_layer": "val1",
            "draft_attn_exclude_by_indexer": "val2",
        }

        indexer_layers_all = {
            "target_indexer_0": "val3",
            "draft_attn_exclude_by_indexer": "val4"
        }

        def get_layers_side_effect(vllm_config, cache_cls):
            if cache_cls == AttentionLayerBase:
                return attn_layers_all
            elif cache_cls == DeepseekV32IndexerCache:
                return indexer_layers_all
            else:
                return {}

        # Setup
        proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
        proposer._init_mtp_model = MagicMock()
        mock_model = MagicMock()
        proposer.model = mock_model

        mock_loader = MagicMock()
        mock_get_loader.return_value = mock_loader
        mock_loader.get_all_weights.return_value = {
            "dummy_weight": torch.tensor([1.0])
        }

        mock_get_layers.side_effect = get_layers_side_effect
        with pytest.raises(AssertionError):
            proposer.load_model(mock_model)

    @patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
    @patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
```
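The removed `test_load_model` above uses a `side_effect` callable so that a single patched function (`get_layers_from_vllm_config`) returns a different fixture depending on which layer class the caller asks for. A stdlib-only illustration of the same dispatch pattern follows; `AttnKind` and `IndexerKind` are made-up stand-ins for `AttentionLayerBase` / `DeepseekV32IndexerCache`.

```python
# Stdlib-only demo of the side_effect dispatch pattern used by the
# removed test; the class names here are invented for the demo.
from unittest.mock import MagicMock


class AttnKind: ...
class IndexerKind: ...


def fake_get_layers(vllm_config, cache_cls):
    # Dispatch on the requested layer class, like get_layers_side_effect.
    if cache_cls is AttnKind:
        return {"target_attn_layer": "val0"}
    if cache_cls is IndexerKind:
        return {"target_indexer_0": "val3"}
    return {}


mock_get_layers = MagicMock(side_effect=fake_get_layers)
assert mock_get_layers(None, AttnKind) == {"target_attn_layer": "val0"}
assert mock_get_layers(None, object) == {}
```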
```diff
@@ -209,78 +158,6 @@ class TestMtpProposer:
        # Check that the model was called the correct number of times
        assert proposer.model.call_count == vllm_config.speculative_config.num_speculative_tokens

    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_generate_token_ids(self, mock_cpu_gpu_buffer):
        mock_buffer_instance = MagicMock()
        mock_cpu_gpu_buffer.return_value = mock_buffer_instance

        mock_deps = MagicMock()
        mock_deps.scheduler_output = MagicMock(spec=SchedulerOutput)
        mock_deps.scheduler_output.num_scheduled_tokens = 16
        mock_deps.spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
        mock_deps.spec_decode_metadata.num_draft_tokens = 2
        mock_deps.runner = MagicMock()
        mock_deps.runner.input_batch = MagicMock(num_reqs=4)
        mock_deps.runner.input_ids = torch.arange(16, dtype=torch.int32)
        mock_deps.runner.spec_decode_common_attn_metadata = MagicMock()
        mock_deps.runner.pcp_size = 2
        mock_deps.runner.dcp_size = 1
        mock_deps.runner.pcp_manager = MagicMock()
        mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer(
            32,
            dtype=torch.int32,
            pin_memory=False,
            device='cpu',
        )
        mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \
            torch.arange(32, dtype=torch.int32)
        mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer(
            5,
            dtype=torch.int32,
            pin_memory=False,
            device='cpu',
        )
        mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \
            torch.tensor([0, 8, 16, 24, 32])
        mock_deps.positions = torch.arange(16, dtype=torch.int32)
        mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16)
        mock_deps.sampled_token_ids = torch.tensor([[100, 101, -1],
                                                    [200, -1, -1],
                                                    [300, 301, 302]])

        proposer = MagicMock(spec=MtpProposer)
        proposer.enable_shared_expert_dp = False
        proposer.runner = mock_deps.runner
        proposer.decode_threshold = 1
        proposer.speculative_config = MagicMock(
            disable_padded_drafter_batch=False)
        proposer.pcp_size = mock_deps.runner.pcp_size
        proposer.dcp_size = mock_deps.runner.dcp_size
        proposer.prepare_next_token_ids_padded = MagicMock(
            return_value=(torch.tensor([101, 200, 302]), 3))
        proposer.prepare_inputs_padded = MagicMock(
            return_value=(MagicMock(), torch.tensor([0, 2, 4]),
                          torch.tensor([7, 15, 23])))
        proposer._propose = MagicMock(
            return_value=torch.tensor([400, 401, 402]))
        proposer.generate_token_ids = MtpProposer.generate_token_ids.__get__(
            proposer)

        draft_token_ids = proposer.generate_token_ids(
            sampled_token_ids=mock_deps.sampled_token_ids,
            scheduler_output=mock_deps.scheduler_output,
            spec_decode_metadata=mock_deps.spec_decode_metadata,
            positions=mock_deps.positions,
            num_scheduled_tokens=mock_deps.scheduler_output.
            num_scheduled_tokens,
            hidden_states=mock_deps.hidden_states,
        )

        proposer.prepare_next_token_ids_padded.assert_called_once()
        proposer.prepare_inputs_padded.assert_called_once()
        proposer._propose.assert_called_once()
        assert torch.equal(draft_token_ids, proposer._propose.return_value)

    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_prepare_next_token_ids_cpu(self, mock_cpu_gpu_buffer):
        mock_buffer_instance = MagicMock()
```
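One detail worth noting in the removed `test_generate_token_ids`: `MtpProposer.generate_token_ids.__get__(proposer)` binds the real method onto a `MagicMock(spec=MtpProposer)`, so the genuine control flow (`prepare_next_token_ids_padded` → `prepare_inputs_padded` → `_propose`) runs against mocked collaborators. A stdlib-only sketch of that binding trick follows; `Pipeline` and its methods are invented for the demo.

```python
# Stdlib-only demo of binding a real method onto a spec'd MagicMock,
# mirroring MtpProposer.generate_token_ids.__get__(proposer) above.
# Pipeline is an invented demo class, not vllm-ascend code.
from unittest.mock import MagicMock


class Pipeline:

    def prepare(self, x):
        raise NotImplementedError

    def propose(self, prepared):
        raise NotImplementedError

    def run(self, x):
        # Real control flow under test: two collaborator calls in order.
        return self.propose(self.prepare(x))


pipeline = MagicMock(spec=Pipeline)
pipeline.prepare = MagicMock(return_value="prepared")
pipeline.propose = MagicMock(return_value="draft")
# Functions are descriptors, so __get__ produces a bound method whose
# self is the mock; run() then executes for real against the stubs.
pipeline.run = Pipeline.run.__get__(pipeline)

assert pipeline.run("input") == "draft"
pipeline.prepare.assert_called_once_with("input")
pipeline.propose.assert_called_once_with("prepared")
```

This keeps the method under test genuine while every collaborator stays controllable, which is why the removed test could assert both the call sequence and the final draft tokens.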