[Cleanup] Remove unused attn_metadata parameter from Proposer classes (#4862)
The `attn_metadata` parameter is not used by any draft proposer, so we can
remove it from the `generate_token_ids()` interface and its implementations.
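
For reference, the trimmed interface ends up looking roughly like this. This is a sketch assembled from the hunks below, not the verbatim source: the four leading parameters and their order are inferred from the `NPUModelRunner` call site at the end of the diff, and since their types are not shown there, they are left unannotated.

import torch

class Proposer:
    # Sketch only: names and defaults are taken from the diff below.
    def generate_token_ids(self,
                           valid_sampled_token_ids,
                           sampling_metadata=None,
                           scheduler_output=None,
                           spec_decode_metadata=None,
                           positions: torch.Tensor = None,
                           num_scheduled_tokens: int = 0,
                           hidden_states: torch.Tensor = None,
                           aux_hidden_states: torch.Tensor = None):
        """Called by execute_model in model_runner"""
        raise NotImplementedError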
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
@@ -238,7 +238,6 @@ class TestMtpProposer:
         proposer.speculative_config = MagicMock(
             disable_padded_drafter_batch=False)
         proposer.pcp_size = mock_deps.runner.pcp_size
-        proposer._get_attn_metadata = MagicMock(return_value=MagicMock())
         proposer.prepare_next_token_ids_padded = MagicMock(
             return_value=(torch.tensor([101, 200, 302]), 3))
         proposer.prepare_inputs_padded = MagicMock(
@@ -261,7 +260,6 @@ class TestMtpProposer:
 
         proposer.prepare_next_token_ids_padded.assert_called_once()
         proposer.prepare_inputs_padded.assert_called_once()
-        proposer._get_attn_metadata.assert_called_once()
         proposer._propose.assert_called_once()
         assert torch.equal(draft_token_ids, proposer._propose.return_value)
 
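The test change mirrors the code change: with `_get_attn_metadata` gone, `TestMtpProposer` no longer stubs it or asserts it was called. A hypothetical condensed version of the remaining assertion pattern (mock names follow the hunks above; the real test drives these calls through `generate_token_ids()`):

from unittest.mock import MagicMock

import torch

proposer = MagicMock()
proposer._propose = MagicMock(return_value=torch.tensor([[101, 200, 302]]))

# Invoke the stub directly to illustrate the pattern: the draft step is
# called exactly once and its result is returned unchanged.
draft_token_ids = proposer._propose()
proposer._propose.assert_called_once()
assert torch.equal(draft_token_ids, proposer._propose.return_value)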
@@ -144,7 +144,6 @@ class EagleProposer(Proposer):
                            positions: torch.Tensor = None,
                            num_scheduled_tokens: int = 0,
                            hidden_states: torch.Tensor = None,
-                           attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
 
         attn_metadata = self._get_eagle_atten_dict(scheduler_output)
@@ -48,7 +48,6 @@ class Proposer:
                            positions: torch.Tensor = None,
                            num_scheduled_tokens: int = 0,
                            hidden_states: torch.Tensor = None,
-                           attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
         """Called by execute_model in model_runner"""
         raise NotImplementedError
@@ -51,10 +51,6 @@ _MTP_MODELS = {
     ("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
 }
 
-_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
-
-_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
-
 
 def _load_model(architecture):
     if architecture not in _MTP_MODELS:
@@ -345,10 +341,8 @@ class MtpProposer(Proposer):
                            positions: torch.Tensor = None,
                            num_scheduled_tokens: int = 0,
                            hidden_states: torch.Tensor = None,
-                           attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
         common_attn_metadata = self.runner.spec_decode_common_attn_metadata
-        attn_metadata = self._get_attn_metadata(attn_metadata)
 
         if self.speculative_config.disable_padded_drafter_batch:
             # When padded-batch is disabled, the sampled_token_ids should be
@@ -487,14 +481,6 @@ class MtpProposer(Proposer):
         model = _load_model(architecture)
         self.model = model(vllm_config=self.vllm_config).to(target_device)
 
-    def _get_attn_metadata(self, attn_metadata):
-        if attn_metadata is not None and isinstance(attn_metadata, dict):
-            architecture = self.vllm_config.model_config.architecture
-            layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
-            attn_metadata = attn_metadata[layer_name]
-
-        return attn_metadata
-
     def _prepare_inputs(
         self,
         common_attn_metadata: CommonAttentionMetadata,
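Taken together with the constants deleted near the top of the file, the removed helper amounted to the following lookup: given the per-layer metadata dict built by the runner, select the entry for the model's first attention layer, which sits at a different index for Qwen3-Next. Since no draft proposer ever consumed the result, the MTP path now derives everything it needs from `self.runner.spec_decode_common_attn_metadata`. A simplified sketch of the deleted logic:

# Dead code after this change; shown only to document what was removed.
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'

def _get_attn_metadata(attn_metadata, architecture):
    # The runner passed a dict keyed by layer name; pick the first
    # attention layer's entry for the given architecture.
    if attn_metadata is not None and isinstance(attn_metadata, dict):
        layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
        return attn_metadata[layer_name]
    return attn_metadata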
@@ -38,7 +38,6 @@ class NgramProposer(VllmNgramProposer, Proposer):
                            positions=None,
                            num_scheduled_tokens=None,
                            hidden_states=None,
-                           attn_metadata=None,
                            aux_hidden_states=None) -> list[list[int]]:
         valid_ngram_requests = []
         for i, sampled_ids in enumerate(valid_sampled_token_ids):
@@ -38,7 +38,6 @@ class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
                            positions=None,
                            num_scheduled_tokens=None,
                            hidden_states=None,
-                           attn_metadata=None,
                            aux_hidden_states=None) -> list[list[int]]:
         draft_token_ids = self.propose(self.runner.input_batch,
                                        valid_sampled_token_ids)
@@ -1383,7 +1383,7 @@ class NPUModelRunner(GPUModelRunner):
         draft_token_ids = self.drafter.generate_token_ids(
             valid_sampled_token_ids, sampling_metadata, scheduler_output,
             spec_decode_metadata, positions, num_scheduled_tokens,
-            hidden_states, attn_metadata, aux_hidden_states)
+            hidden_states, aux_hidden_states)
         return draft_token_ids
 
     def _select_moe_comm_method(self,