[Refactor][EAGLE] 4/N extract common methods from eagle and mtp (#5870)
### What this PR does / why we need it?
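This PR extracts methods shared by eagle_proposer and mtp_proposer into `EagleProposer` (`maybe_pad_and_reduce` and `maybe_all_gather_and_unpad`) and unifies attribute names across both proposers (`attn_layer_name` -> `attn_layer_names`, `use_aclgraph` -> `use_cuda_graph`).
This is a small step towards merging eagle and mtp.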
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Tested by CI.
- vLLM version: v0.13.0
- vLLM main: bde38c11df
---------
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -165,7 +165,7 @@ class TestEagleProposerLoadModel(TestBase):
|
|||||||
|
|
||||||
self.proposer.load_model(mock_model)
|
self.proposer.load_model(mock_model)
|
||||||
mock_get_model.assert_called_once()
|
mock_get_model.assert_called_once()
|
||||||
self.assertEqual(self.proposer.attn_layer_name, ["layer3"])
|
self.assertEqual(self.proposer.attn_layer_names, ["layer3"])
|
||||||
self.assertIs(self.proposer.model.model.embed_tokens,
|
self.assertIs(self.proposer.model.model.embed_tokens,
|
||||||
mock_model.model.embed_tokens)
|
mock_model.model.embed_tokens)
|
||||||
|
|
||||||
@@ -196,7 +196,7 @@ class TestEagleProposerLoadModel(TestBase):
|
|||||||
|
|
||||||
self.assertIsNot(self.proposer.model.model.embed_tokens,
|
self.assertIsNot(self.proposer.model.model.embed_tokens,
|
||||||
mock_model.model.embed_tokens)
|
mock_model.model.embed_tokens)
|
||||||
self.assertEqual(self.proposer.attn_layer_name, ["layer2"])
|
self.assertEqual(self.proposer.attn_layer_names, ["layer2"])
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
|
"vllm_ascend.spec_decode.eagle_proposer.get_layers_from_vllm_config")
|
||||||
@@ -239,6 +239,8 @@ class TestEagleProposerDummyRun(TestBase):
|
|||||||
self.vllm_config.speculative_config.num_speculative_tokens = 4
|
self.vllm_config.speculative_config.num_speculative_tokens = 4
|
||||||
self.device = torch.device("cpu")
|
self.device = torch.device("cpu")
|
||||||
self.runner = MagicMock()
|
self.runner = MagicMock()
|
||||||
|
self.runner.pcp_size = 1
|
||||||
|
self.runner.dcp_size = 1
|
||||||
|
|
||||||
self.vllm_config.cache_config.block_size = 16
|
self.vllm_config.cache_config.block_size = 16
|
||||||
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
|
self.vllm_config.scheduler_config.max_num_batched_tokens = 1024
|
||||||
@@ -246,6 +248,7 @@ class TestEagleProposerDummyRun(TestBase):
|
|||||||
self.vllm_config.model_config.dtype = torch.float16
|
self.vllm_config.model_config.dtype = torch.float16
|
||||||
self.vllm_config.model_config.max_model_len = 2048
|
self.vllm_config.model_config.max_model_len = 2048
|
||||||
self.vllm_config.model_config.uses_mrope = False
|
self.vllm_config.model_config.uses_mrope = False
|
||||||
|
self.vllm_config.model_config.use_mla = False
|
||||||
self.vllm_config.speculative_config.speculative_token_tree = str([
|
self.vllm_config.speculative_config.speculative_token_tree = str([
|
||||||
(i + 1) * (0, ) for i in range(4)
|
(i + 1) * (0, ) for i in range(4)
|
||||||
])
|
])
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class TestMtpProposer:
|
|||||||
config.additional_config = None
|
config.additional_config = None
|
||||||
config.speculative_config = MagicMock(spec=SpeculativeConfig)
|
config.speculative_config = MagicMock(spec=SpeculativeConfig)
|
||||||
config.speculative_config.num_speculative_tokens = 2
|
config.speculative_config.num_speculative_tokens = 2
|
||||||
config.speculative_config.method = "deepseek_mtp"
|
config.speculative_config.method = "mtp"
|
||||||
config.speculative_config.draft_model_config = MagicMock()
|
config.speculative_config.draft_model_config = MagicMock()
|
||||||
config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
|
config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
|
||||||
config.speculative_config.speculative_token_tree = str([
|
config.speculative_config.speculative_token_tree = str([
|
||||||
@@ -98,9 +98,11 @@ class TestMtpProposer:
|
|||||||
mock_buffer_instance = MagicMock()
|
mock_buffer_instance = MagicMock()
|
||||||
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
|
mock_cpu_gpu_buffer.return_value = mock_buffer_instance
|
||||||
runner._use_aclgraph.return_value = True
|
runner._use_aclgraph.return_value = True
|
||||||
|
vllm_config.scheduler_config.async_scheduling = False
|
||||||
|
vllm_config.speculative_config.enforce_eager = False
|
||||||
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
|
proposer = MtpProposer(vllm_config, torch.device("cpu"), runner)
|
||||||
|
|
||||||
assert proposer.use_aclgraph is True
|
assert proposer.use_cuda_graph is True
|
||||||
|
|
||||||
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
|
@patch("vllm_ascend.spec_decode.mtp_proposer.get_forward_context")
|
||||||
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
|
@patch("vllm_ascend.spec_decode.mtp_proposer.set_ascend_forward_context")
|
||||||
|
|||||||
@@ -91,23 +91,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
super().__init__(vllm_config, device, runner)
|
super().__init__(vllm_config, device, runner)
|
||||||
|
|
||||||
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
|
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
|
||||||
# there is synchronization between mtp steps when enabling aclgraph,
|
|
||||||
# disable aclgraph when use async scheduling to avoid the
|
|
||||||
# synchronization overhead.
|
|
||||||
# NOTE: we need to set aclgraph_runtime_mode to None in both dummy_run
|
|
||||||
# and _propose.
|
|
||||||
self.use_cuda_graph = (
|
|
||||||
self.vllm_config.compilation_config.mode
|
|
||||||
== CompilationMode.VLLM_COMPILE
|
|
||||||
and not self.vllm_config.model_config.enforce_eager
|
|
||||||
and not self.use_async_scheduling
|
|
||||||
and not self.vllm_config.speculative_config.enforce_eager)
|
|
||||||
|
|
||||||
self.cudagraph_batch_sizes = list(
|
|
||||||
sorted(
|
|
||||||
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
|
||||||
|
|
||||||
self.pcp_size = self.runner.pcp_size
|
|
||||||
self.decode_threshold = 1 + self.num_speculative_tokens
|
self.decode_threshold = 1 + self.num_speculative_tokens
|
||||||
self.query_start_loc = self.runner._make_buffer(
|
self.query_start_loc = self.runner._make_buffer(
|
||||||
self.runner.max_num_reqs + 1, dtype=torch.int32)
|
self.runner.max_num_reqs + 1, dtype=torch.int32)
|
||||||
@@ -118,12 +102,11 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
|
|
||||||
self.enable_shared_expert_dp = shared_expert_dp_enabled()
|
self.enable_shared_expert_dp = shared_expert_dp_enabled()
|
||||||
|
|
||||||
|
self.pcp_size = self.runner.pcp_size
|
||||||
self.dcp_size = self.runner.dcp_size
|
self.dcp_size = self.runner.dcp_size
|
||||||
self.pcp_rank = self.runner.pcp_rank
|
self.pcp_rank = self.runner.pcp_rank
|
||||||
self.dcp_rank = self.runner.dcp_rank
|
self.dcp_rank = self.runner.dcp_rank
|
||||||
|
|
||||||
self.use_aclgraph = self.runner._use_aclgraph()
|
|
||||||
|
|
||||||
self.full_indices = range(
|
self.full_indices = range(
|
||||||
self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
|
self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
|
||||||
self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
|
self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
|
||||||
@@ -131,6 +114,10 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
|
self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
|
||||||
"index_topk")
|
"index_topk")
|
||||||
|
|
||||||
|
self.use_cuda_graph = (self.runner._use_aclgraph()
|
||||||
|
and not self.speculative_config.enforce_eager
|
||||||
|
and not self.use_async_scheduling)
|
||||||
|
|
||||||
# TODO: Remove it when the bug of fx-graph is solved
|
# TODO: Remove it when the bug of fx-graph is solved
|
||||||
self.maybe_eager_context: ContextManager[Any] = nullcontext()
|
self.maybe_eager_context: ContextManager[Any] = nullcontext()
|
||||||
if not self.use_cuda_graph and enable_sp(vllm_config):
|
if not self.use_cuda_graph and enable_sp(vllm_config):
|
||||||
@@ -158,8 +145,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
|
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
|
||||||
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
|
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
|
||||||
assert len(draft_attn_layer_names) == 1
|
assert len(draft_attn_layer_names) == 1
|
||||||
self.attn_layer_name = list(draft_attn_layer_names)
|
self.attn_layer_names = list(draft_attn_layer_names)
|
||||||
self.attn_layer_names = self.attn_layer_name
|
|
||||||
|
|
||||||
# share embed_tokens with the target model if needed
|
# share embed_tokens with the target model if needed
|
||||||
if get_pp_group().world_size == 1:
|
if get_pp_group().world_size == 1:
|
||||||
@@ -273,7 +259,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
attn_metadata_eagle = builder.build_for_graph_capture(
|
attn_metadata_eagle = builder.build_for_graph_capture(
|
||||||
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
|
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
|
||||||
attn_metadata = {}
|
attn_metadata = {}
|
||||||
for layer_name in self.attn_layer_name:
|
for layer_name in self.attn_layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_eagle
|
attn_metadata[layer_name] = attn_metadata_eagle
|
||||||
|
|
||||||
model_input_ids = self.input_ids[:num_tokens]
|
model_input_ids = self.input_ids[:num_tokens]
|
||||||
@@ -292,30 +278,22 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
is_draft_model=True):
|
is_draft_model=True):
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
model_previous_hidden_states, model_positions = self.maybe_pad_and_reduce(
|
||||||
if forward_context.sp_enabled:
|
model_previous_hidden_states, model_positions)
|
||||||
model_previous_hidden_states = split_inputs_tp_to_sp(
|
|
||||||
model_previous_hidden_states,
|
|
||||||
model_previous_hidden_states)
|
|
||||||
|
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=model_input_ids,
|
input_ids=model_input_ids,
|
||||||
positions=model_positions,
|
positions=model_positions,
|
||||||
hidden_states=model_previous_hidden_states,
|
hidden_states=model_previous_hidden_states,
|
||||||
)
|
)
|
||||||
|
forward_context = get_forward_context()
|
||||||
if (forward_context.cudagraph_runtime_mode
|
if (forward_context.cudagraph_runtime_mode
|
||||||
== CUDAGraphMode.FULL
|
== CUDAGraphMode.FULL
|
||||||
and not forward_context.capturing):
|
and not forward_context.capturing):
|
||||||
update_attn_params(
|
self._update_full_graph_params(forward_context, num_tokens)
|
||||||
self.update_stream,
|
|
||||||
forward_context,
|
|
||||||
num_tokens,
|
|
||||||
self.vllm_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
if forward_context.sp_enabled:
|
model_previous_hidden_states, model_positions, _ = self.maybe_all_gather_and_unpad(
|
||||||
model_previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
model_previous_hidden_states, model_positions)
|
||||||
model_previous_hidden_states, True)
|
|
||||||
|
|
||||||
dummy_compute_logits(self.hidden_states)
|
dummy_compute_logits(self.hidden_states)
|
||||||
|
|
||||||
@@ -362,7 +340,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
self.input_ids[last_token_indices] = next_token_ids
|
self.input_ids[last_token_indices] = next_token_ids
|
||||||
|
|
||||||
if self.use_cuda_graph and \
|
if self.use_cuda_graph and \
|
||||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
|
||||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||||
else:
|
else:
|
||||||
num_input_tokens = num_tokens
|
num_input_tokens = num_tokens
|
||||||
@@ -386,7 +364,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
# update global cos, sin
|
# update global cos, sin
|
||||||
update_cos_sin(self.positions[:num_input_tokens])
|
update_cos_sin(self.positions[:num_input_tokens])
|
||||||
per_layer_attn_metadata = {}
|
per_layer_attn_metadata = {}
|
||||||
for layer_name in self.attn_layer_name:
|
for layer_name in self.attn_layer_names:
|
||||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
per_layer_attn_metadata,
|
per_layer_attn_metadata,
|
||||||
@@ -403,34 +381,27 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
model_positions = self.positions[:num_input_tokens]
|
model_positions = self.positions[:num_input_tokens]
|
||||||
model_hidden_states = self.hidden_states[:num_input_tokens]
|
model_hidden_states = self.hidden_states[:num_input_tokens]
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
model_hidden_states, model_positions = self.maybe_pad_and_reduce(
|
||||||
if forward_context.sp_enabled:
|
model_hidden_states, model_positions)
|
||||||
# split hidden states along sequence dimension
|
|
||||||
# positions should not be split?
|
|
||||||
model_hidden_states = split_inputs_tp_to_sp(
|
|
||||||
model_hidden_states, model_hidden_states)
|
|
||||||
|
|
||||||
last_hidden_states, hidden_states = self.model(
|
ret_hidden_states = self.model(
|
||||||
input_ids=model_input_ids,
|
input_ids=model_input_ids,
|
||||||
positions=model_positions,
|
positions=model_positions,
|
||||||
hidden_states=model_hidden_states,
|
hidden_states=model_hidden_states,
|
||||||
)
|
)
|
||||||
|
if self.method == "mtp":
|
||||||
|
last_hidden_states = ret_hidden_states
|
||||||
|
hidden_states = last_hidden_states
|
||||||
|
else:
|
||||||
|
last_hidden_states, hidden_states = ret_hidden_states
|
||||||
|
|
||||||
|
forward_context = get_forward_context()
|
||||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
# TODO: support mla in future.
|
self._update_full_graph_params(forward_context,
|
||||||
update_attn_params(
|
num_input_tokens)
|
||||||
self.update_stream,
|
|
||||||
forward_context,
|
|
||||||
num_input_tokens,
|
|
||||||
self.vllm_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
if forward_context.sp_enabled:
|
last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
|
||||||
# merge hidden states along sequence dimension
|
last_hidden_states, model_positions, hidden_states)
|
||||||
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
|
||||||
last_hidden_states.contiguous(), True)
|
|
||||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
|
||||||
hidden_states.contiguous(), True)
|
|
||||||
|
|
||||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states)
|
logits = self.model.compute_logits(sample_hidden_states)
|
||||||
@@ -453,7 +424,7 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
last_token_indices = self.arange[:batch_size]
|
last_token_indices = self.arange[:batch_size]
|
||||||
|
|
||||||
if self.use_cuda_graph and \
|
if self.use_cuda_graph and \
|
||||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
batch_size <= self.runner.cudagraph_batch_sizes[-1]:
|
||||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||||
else:
|
else:
|
||||||
input_batch_size = batch_size
|
input_batch_size = batch_size
|
||||||
@@ -556,32 +527,27 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
model_positions = self.positions[:input_batch_size]
|
model_positions = self.positions[:input_batch_size]
|
||||||
model_hidden_states = self.hidden_states[:input_batch_size]
|
model_hidden_states = self.hidden_states[:input_batch_size]
|
||||||
|
|
||||||
forward_context = get_forward_context()
|
model_hidden_states, model_positions = self.maybe_pad_and_reduce(
|
||||||
if forward_context.sp_enabled:
|
model_hidden_states, model_positions)
|
||||||
# split hidden states along sequence dimension
|
|
||||||
# positions should not be split?
|
|
||||||
model_hidden_states = split_inputs_tp_to_sp(
|
|
||||||
model_hidden_states, model_hidden_states)
|
|
||||||
|
|
||||||
last_hidden_states, hidden_states = self.model(
|
ret_hidden_states = self.model(
|
||||||
input_ids=model_input_ids,
|
input_ids=model_input_ids,
|
||||||
positions=model_positions,
|
positions=model_positions,
|
||||||
hidden_states=model_hidden_states,
|
hidden_states=model_hidden_states,
|
||||||
)
|
)
|
||||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
if self.method == "mtp":
|
||||||
update_attn_params(
|
last_hidden_states = ret_hidden_states
|
||||||
self.update_stream,
|
hidden_states = last_hidden_states
|
||||||
forward_context,
|
else:
|
||||||
input_batch_size,
|
last_hidden_states, hidden_states = ret_hidden_states
|
||||||
self.vllm_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
if forward_context.sp_enabled:
|
forward_context = get_forward_context()
|
||||||
# merge hidden states along sequence dimension
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
self._update_full_graph_params(forward_context,
|
||||||
last_hidden_states.contiguous(), True)
|
input_batch_size)
|
||||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
|
||||||
hidden_states.contiguous(), True)
|
last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
|
||||||
|
last_hidden_states, model_positions, hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states[:batch_size]
|
hidden_states = hidden_states[:batch_size]
|
||||||
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
||||||
@@ -948,3 +914,46 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
else:
|
else:
|
||||||
update_attn_params(self.update_stream, forward_context,
|
update_attn_params(self.update_stream, forward_context,
|
||||||
num_tokens, self.vllm_config)
|
num_tokens, self.vllm_config)
|
||||||
|
|
||||||
|
def maybe_pad_and_reduce(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if self.method == "mtp":
|
||||||
|
if self.enable_shared_expert_dp:
|
||||||
|
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||||
|
hidden_states)
|
||||||
|
positions = positions.unsqueeze(-1)
|
||||||
|
positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
|
||||||
|
positions = positions.squeeze(-1)
|
||||||
|
else:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
if forward_context.sp_enabled:
|
||||||
|
hidden_states = split_inputs_tp_to_sp(
|
||||||
|
hidden_states, hidden_states)
|
||||||
|
return hidden_states, positions
|
||||||
|
|
||||||
|
def maybe_all_gather_and_unpad(
|
||||||
|
self,
|
||||||
|
last_hidden_states: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||||
|
if self.method == "mtp":
|
||||||
|
if self.enable_shared_expert_dp:
|
||||||
|
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
last_hidden_states.contiguous(), True)
|
||||||
|
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
positions.contiguous(), True)
|
||||||
|
if hidden_states is not None:
|
||||||
|
hidden_states = last_hidden_states
|
||||||
|
else:
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
if forward_context.sp_enabled:
|
||||||
|
last_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
last_hidden_states.contiguous(), True)
|
||||||
|
if hidden_states is not None:
|
||||||
|
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
hidden_states.contiguous(), True)
|
||||||
|
return last_hidden_states, positions, hidden_states
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ class MtpProposer(EagleProposer):
|
|||||||
attn_metadata_mtp = builder.build_for_graph_capture(
|
attn_metadata_mtp = builder.build_for_graph_capture(
|
||||||
common_attn_metadata, attn_state)
|
common_attn_metadata, attn_state)
|
||||||
attn_metadata = {}
|
attn_metadata = {}
|
||||||
for layer_name in self.attn_layer_name:
|
for layer_name in self.attn_layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_mtp
|
attn_metadata[layer_name] = attn_metadata_mtp
|
||||||
else:
|
else:
|
||||||
attn_metadata = None
|
attn_metadata = None
|
||||||
@@ -112,12 +112,8 @@ class MtpProposer(EagleProposer):
|
|||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
is_draft_model=True,
|
is_draft_model=True,
|
||||||
in_profile_run=is_profile):
|
in_profile_run=is_profile):
|
||||||
if self.enable_shared_expert_dp:
|
previous_hidden_states, positions = self.maybe_pad_and_reduce(
|
||||||
positions = positions.unsqueeze(-1)
|
previous_hidden_states, positions)
|
||||||
positions = torch.ops.vllm.maybe_pad_and_reduce(positions)
|
|
||||||
positions = positions.squeeze(-1)
|
|
||||||
previous_hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
|
||||||
previous_hidden_states)
|
|
||||||
self.model(input_ids=input_ids,
|
self.model(input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=previous_hidden_states)
|
hidden_states=previous_hidden_states)
|
||||||
@@ -126,11 +122,8 @@ class MtpProposer(EagleProposer):
|
|||||||
not forward_context.capturing and not self.use_sparse:
|
not forward_context.capturing and not self.use_sparse:
|
||||||
self._update_full_graph_params(forward_context, num_tokens)
|
self._update_full_graph_params(forward_context, num_tokens)
|
||||||
|
|
||||||
if self.enable_shared_expert_dp:
|
previous_hidden_states, positions, _ = self.maybe_all_gather_and_unpad(
|
||||||
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
previous_hidden_states, positions)
|
||||||
positions, True)
|
|
||||||
previous_hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
|
||||||
previous_hidden_states, True)
|
|
||||||
dummy_compute_logits(previous_hidden_states)
|
dummy_compute_logits(previous_hidden_states)
|
||||||
if with_prefill:
|
if with_prefill:
|
||||||
break
|
break
|
||||||
@@ -249,11 +242,11 @@ class MtpProposer(EagleProposer):
|
|||||||
assert self.runner is not None
|
assert self.runner is not None
|
||||||
|
|
||||||
# Note(qcs): We may need to refactor these check logics.
|
# Note(qcs): We may need to refactor these check logics.
|
||||||
if self.runner.use_aclgraph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[
|
if self.use_cuda_graph and num_scheduled_tokens <= self.runner.cudagraph_batch_sizes[
|
||||||
-1]:
|
-1]:
|
||||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(
|
||||||
num_scheduled_tokens)
|
num_scheduled_tokens)
|
||||||
elif self.use_aclgraph and num_tokens <= self.runner.cudagraph_batch_sizes[
|
elif self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[
|
||||||
-1]:
|
-1]:
|
||||||
# Acl graph mode, add padding to the batch size
|
# Acl graph mode, add padding to the batch size
|
||||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||||
@@ -304,7 +297,7 @@ class MtpProposer(EagleProposer):
|
|||||||
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
||||||
self.runner.get_model())
|
self.runner.get_model())
|
||||||
attn_metadata = {}
|
attn_metadata = {}
|
||||||
for layer_name in self.attn_layer_name:
|
for layer_name in self.attn_layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_mtp
|
attn_metadata[layer_name] = attn_metadata_mtp
|
||||||
|
|
||||||
for step in range(self.num_speculative_tokens):
|
for step in range(self.num_speculative_tokens):
|
||||||
@@ -324,26 +317,8 @@ class MtpProposer(EagleProposer):
|
|||||||
positions = self.positions[:num_input_tokens]
|
positions = self.positions[:num_input_tokens]
|
||||||
hidden_states = self.hidden_states[:num_input_tokens]
|
hidden_states = self.hidden_states[:num_input_tokens]
|
||||||
|
|
||||||
if self.enable_shared_expert_dp:
|
hidden_states, positions = self.maybe_pad_and_reduce(
|
||||||
# positions [N] -> [N, 1] for padding
|
hidden_states, positions)
|
||||||
positions = positions.unsqueeze(-1)
|
|
||||||
positions = torch.ops.vllm.maybe_pad_and_reduce(
|
|
||||||
positions)
|
|
||||||
positions = positions.squeeze(-1)
|
|
||||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
|
||||||
hidden_states)
|
|
||||||
|
|
||||||
for layer_name in self.attn_layer_name:
|
|
||||||
decode_metadata = getattr(attn_metadata[layer_name],
|
|
||||||
"decode", None)
|
|
||||||
if self.use_async_scheduling and decode_metadata is not None:
|
|
||||||
actual_size = len(
|
|
||||||
decode_metadata.actual_seq_lengths_q)
|
|
||||||
|
|
||||||
decode_metadata.seq_lens_list = \
|
|
||||||
decode_metadata.seq_lens_list[:actual_size]
|
|
||||||
decode_metadata.block_table = \
|
|
||||||
decode_metadata.block_table[:actual_size]
|
|
||||||
|
|
||||||
hidden_states = self.model(input_ids=input_ids,
|
hidden_states = self.model(input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@@ -353,11 +328,8 @@ class MtpProposer(EagleProposer):
|
|||||||
self._update_full_graph_params(forward_context,
|
self._update_full_graph_params(forward_context,
|
||||||
num_input_tokens)
|
num_input_tokens)
|
||||||
|
|
||||||
if self.enable_shared_expert_dp:
|
hidden_states, positions, _ = self.maybe_all_gather_and_unpad(
|
||||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
hidden_states, positions)
|
||||||
hidden_states.contiguous(), True)
|
|
||||||
positions = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
|
||||||
positions.contiguous(), True)
|
|
||||||
|
|
||||||
num_indices = last_token_indices.shape[0]
|
num_indices = last_token_indices.shape[0]
|
||||||
if lmhead_tp_enable():
|
if lmhead_tp_enable():
|
||||||
@@ -398,7 +370,7 @@ class MtpProposer(EagleProposer):
|
|||||||
if step == self.num_speculative_tokens - 1 or with_prefill:
|
if step == self.num_speculative_tokens - 1 or with_prefill:
|
||||||
break
|
break
|
||||||
|
|
||||||
attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
|
attn_metadata_i = attn_metadata[self.attn_layer_names[0]]
|
||||||
|
|
||||||
if step == 0:
|
if step == 0:
|
||||||
positions = target_positions[last_token_indices]
|
positions = target_positions[last_token_indices]
|
||||||
|
|||||||
Reference in New Issue
Block a user