[Cleanup] Remove unused attn_metadata parameter from Proposer classes (#4862)
The `attn_metadata` is not used by any draft proposer, so we can remove
it.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -238,7 +238,6 @@ class TestMtpProposer:
|
||||
proposer.speculative_config = MagicMock(
|
||||
disable_padded_drafter_batch=False)
|
||||
proposer.pcp_size = mock_deps.runner.pcp_size
|
||||
proposer._get_attn_metadata = MagicMock(return_value=MagicMock())
|
||||
proposer.prepare_next_token_ids_padded = MagicMock(
|
||||
return_value=(torch.tensor([101, 200, 302]), 3))
|
||||
proposer.prepare_inputs_padded = MagicMock(
|
||||
@@ -261,7 +260,6 @@ class TestMtpProposer:
|
||||
|
||||
proposer.prepare_next_token_ids_padded.assert_called_once()
|
||||
proposer.prepare_inputs_padded.assert_called_once()
|
||||
proposer._get_attn_metadata.assert_called_once()
|
||||
proposer._propose.assert_called_once()
|
||||
assert torch.equal(draft_token_ids, proposer._propose.return_value)
|
||||
|
||||
|
||||
@@ -144,7 +144,6 @@ class EagleProposer(Proposer):
|
||||
positions: torch.Tensor = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
hidden_states: torch.Tensor = None,
|
||||
attn_metadata=None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
|
||||
attn_metadata = self._get_eagle_atten_dict(scheduler_output)
|
||||
|
||||
@@ -48,7 +48,6 @@ class Proposer:
|
||||
positions: torch.Tensor = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
hidden_states: torch.Tensor = None,
|
||||
attn_metadata=None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
"""Called by execute_model in model_runner"""
|
||||
raise NotImplementedError
|
||||
@@ -51,10 +51,6 @@ _MTP_MODELS = {
|
||||
("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
|
||||
}
|
||||
|
||||
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
||||
|
||||
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
||||
|
||||
|
||||
def _load_model(architecture):
|
||||
if architecture not in _MTP_MODELS:
|
||||
@@ -345,10 +341,8 @@ class MtpProposer(Proposer):
|
||||
positions: torch.Tensor = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
hidden_states: torch.Tensor = None,
|
||||
attn_metadata=None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
||||
attn_metadata = self._get_attn_metadata(attn_metadata)
|
||||
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
# When padded-batch is disabled, the sampled_token_ids should be
|
||||
@@ -487,14 +481,6 @@ class MtpProposer(Proposer):
|
||||
model = _load_model(architecture)
|
||||
self.model = model(vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
def _get_attn_metadata(self, attn_metadata):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
architecture = self.vllm_config.model_config.architecture
|
||||
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
@@ -38,7 +38,6 @@ class NgramProposer(VllmNgramProposer, Proposer):
|
||||
positions=None,
|
||||
num_scheduled_tokens=None,
|
||||
hidden_states=None,
|
||||
attn_metadata=None,
|
||||
aux_hidden_states=None) -> list[list[int]]:
|
||||
valid_ngram_requests = []
|
||||
for i, sampled_ids in enumerate(valid_sampled_token_ids):
|
||||
|
||||
@@ -38,7 +38,6 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
|
||||
positions=None,
|
||||
num_scheduled_tokens=None,
|
||||
hidden_states=None,
|
||||
attn_metadata=None,
|
||||
aux_hidden_states=None) -> list[list[int]]:
|
||||
draft_token_ids = self.propose(self.runner.input_batch,
|
||||
valid_sampled_token_ids)
|
||||
|
||||
@@ -1383,7 +1383,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
draft_token_ids = self.drafter.generate_token_ids(
|
||||
valid_sampled_token_ids, sampling_metadata, scheduler_output,
|
||||
spec_decode_metadata, positions, num_scheduled_tokens,
|
||||
hidden_states, attn_metadata, aux_hidden_states)
|
||||
hidden_states, aux_hidden_states)
|
||||
return draft_token_ids
|
||||
|
||||
def _select_moe_comm_method(self,
|
||||
|
||||
Reference in New Issue
Block a user