[Feat][Graph] Support MTP for ACL Graph (#2932)
### What this PR does / why we need it?
This PR depends on the merge of #2707 and has adapted the aclgraph
functionality to support MTP.
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
2b85697031
---------
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
@@ -39,7 +39,7 @@ def mtp_correctness(
|
|||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.7,
|
||||||
max_model_len=256,
|
max_model_len=256,
|
||||||
enforce_eager=True) as ref_llm:
|
enforce_eager=False) as ref_llm:
|
||||||
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
@@ -53,7 +53,7 @@ def mtp_correctness(
|
|||||||
"method": "deepseek_mtp",
|
"method": "deepseek_mtp",
|
||||||
"num_speculative_tokens": num_speculative_tokens,
|
"num_speculative_tokens": num_speculative_tokens,
|
||||||
},
|
},
|
||||||
enforce_eager=True,
|
enforce_eager=False,
|
||||||
max_model_len=2000,
|
max_model_len=2000,
|
||||||
additional_config={"ascend_scheduler_config": {
|
additional_config={"ascend_scheduler_config": {
|
||||||
"enabled": False
|
"enabled": False
|
||||||
|
|||||||
@@ -186,6 +186,34 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
|
ascend_config = MagicMock()
|
||||||
|
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||||
|
return_value=ascend_config):
|
||||||
|
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||||
|
mock_device)
|
||||||
|
|
||||||
|
self.assertEqual(builder.block_size,
|
||||||
|
mock_vllm_config.cache_config.block_size)
|
||||||
|
self.assertEqual(
|
||||||
|
builder.chunked_prefill_enabled,
|
||||||
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
|
||||||
|
|
||||||
|
def test_ascend_mla_metadata_builder_spec_decode(self):
|
||||||
|
mock_vllm_config = MagicMock()
|
||||||
|
mock_vllm_config.model_config.max_model_len = 1024
|
||||||
|
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||||
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
|
mock_vllm_config.cache_config.block_size = 16
|
||||||
|
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||||
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_spec_config = MagicMock()
|
||||||
|
mock_spec_config.num_speculative_tokens = 3
|
||||||
|
mock_vllm_config.speculative_config = mock_spec_config
|
||||||
|
|
||||||
ascend_config = MagicMock()
|
ascend_config = MagicMock()
|
||||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||||
return_value=ascend_config):
|
return_value=ascend_config):
|
||||||
@@ -208,6 +236,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||||
return_value=ascend_config):
|
return_value=ascend_config):
|
||||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||||
|
|||||||
@@ -190,6 +190,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
ascend_config = MagicMock()
|
ascend_config = MagicMock()
|
||||||
ascend_config.torchair_graph_config = MagicMock()
|
ascend_config.torchair_graph_config = MagicMock()
|
||||||
ascend_config.torchair_graph_config.enabled = True
|
ascend_config.torchair_graph_config.enabled = True
|
||||||
@@ -217,6 +219,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
ascend_config.torchair_graph_config = MagicMock()
|
ascend_config.torchair_graph_config = MagicMock()
|
||||||
ascend_config.torchair_graph_config.enabled = True
|
ascend_config.torchair_graph_config.enabled = True
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||||
mock_vllm_config,
|
mock_vllm_config,
|
||||||
mock_device)
|
mock_device)
|
||||||
@@ -252,6 +256,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
|
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
|
||||||
return_value=ascend_config):
|
return_value=ascend_config):
|
||||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||||
@@ -288,6 +294,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||||
mock_vllm_config,
|
mock_vllm_config,
|
||||||
mock_device)
|
mock_device)
|
||||||
@@ -309,6 +317,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||||
mock_vllm_config,
|
mock_vllm_config,
|
||||||
mock_device)
|
mock_device)
|
||||||
@@ -331,6 +341,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
builder = AscendMLATorchairMetadataBuilder(None, None,
|
builder = AscendMLATorchairMetadataBuilder(None, None,
|
||||||
mock_vllm_config,
|
mock_vllm_config,
|
||||||
mock_device)
|
mock_device)
|
||||||
@@ -357,6 +369,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
mock_vllm_config.model_config.dtype = torch.float16
|
mock_vllm_config.model_config.dtype = torch.float16
|
||||||
mock_device = 'cpu'
|
mock_device = 'cpu'
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
builder = AscendMLATorchairMetadataBuilder(
|
builder = AscendMLATorchairMetadataBuilder(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
@@ -424,6 +438,8 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
|
|||||||
model = MagicMock(spec=nn.Module)
|
model = MagicMock(spec=nn.Module)
|
||||||
model.model = MagicMock(spec=nn.Module)
|
model.model = MagicMock(spec=nn.Module)
|
||||||
|
|
||||||
|
mock_vllm_config.speculative_config = None
|
||||||
|
|
||||||
builder = AscendMLATorchairMetadataBuilder(
|
builder = AscendMLATorchairMetadataBuilder(
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -187,7 +187,14 @@ class AscendMLAMetadataBuilder:
|
|||||||
self.block_size - 1) // self.block_size
|
self.block_size - 1) // self.block_size
|
||||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||||
|
|
||||||
|
self.speculative_config = vllm_config.speculative_config
|
||||||
self.decode_threshold = 1
|
self.decode_threshold = 1
|
||||||
|
if self.speculative_config:
|
||||||
|
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||||
|
self.decode_threshold += spec_token_num
|
||||||
|
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
|
||||||
|
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||||
|
got {self.decode_threshold}"
|
||||||
|
|
||||||
if self.chunked_prefill_enabled:
|
if self.chunked_prefill_enabled:
|
||||||
self.chunked_prefill_workspace_size = min(
|
self.chunked_prefill_workspace_size = min(
|
||||||
@@ -275,7 +282,6 @@ class AscendMLAMetadataBuilder:
|
|||||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
|
||||||
assert num_decodes + num_prefills == num_reqs
|
assert num_decodes + num_prefills == num_reqs
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from torchair import patch_for_hcom
|
|||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
|
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
|
||||||
set_current_vllm_config)
|
set_current_vllm_config)
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
from vllm.model_executor.model_loader.utils import (
|
from vllm.model_executor.model_loader.utils import (
|
||||||
process_weights_after_loading, set_default_torch_dtype)
|
process_weights_after_loading, set_default_torch_dtype)
|
||||||
@@ -363,8 +363,14 @@ class MtpProposer(Proposer):
|
|||||||
not self.runner.with_prefill
|
not self.runner.with_prefill
|
||||||
|
|
||||||
if is_running_torchair:
|
if is_running_torchair:
|
||||||
|
# Torchair graph mode, padding is same as the main model
|
||||||
num_input_tokens = self.runner.graph_pad_size
|
num_input_tokens = self.runner.graph_pad_size
|
||||||
|
elif (self.runner.use_aclgraph
|
||||||
|
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
|
||||||
|
# Acl graph mode, add padding to the batch size
|
||||||
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||||
else:
|
else:
|
||||||
|
# Eager mode, no padding needed
|
||||||
num_input_tokens = num_tokens
|
num_input_tokens = num_tokens
|
||||||
|
|
||||||
seq_lens = target_positions[last_token_indices] + 1
|
seq_lens = target_positions[last_token_indices] + 1
|
||||||
@@ -410,7 +416,7 @@ class MtpProposer(Proposer):
|
|||||||
# TODO: adapt enable_dbo later
|
# TODO: adapt enable_dbo later
|
||||||
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
(num_input_tokens, num_tokens_across_dp, with_prefill,
|
||||||
_) = self.runner._sync_metadata_across_dp(
|
_) = self.runner._sync_metadata_across_dp(
|
||||||
num_tokens, self.runner.with_prefill, False)
|
num_input_tokens, self.runner.with_prefill, False)
|
||||||
else:
|
else:
|
||||||
# torchair mode can reuse self.runner.num_tokens_across_dp
|
# torchair mode can reuse self.runner.num_tokens_across_dp
|
||||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
||||||
@@ -418,6 +424,10 @@ class MtpProposer(Proposer):
|
|||||||
|
|
||||||
moe_comm_method = self.runner._select_moe_comm_method(
|
moe_comm_method = self.runner._select_moe_comm_method(
|
||||||
num_input_tokens, with_prefill)
|
num_input_tokens, with_prefill)
|
||||||
|
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
|
||||||
|
uniform_decode=False)
|
||||||
|
aclgraph_runtime_mode, batch_descriptor = \
|
||||||
|
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
|
||||||
|
|
||||||
for step in range(self.num_speculative_tokens):
|
for step in range(self.num_speculative_tokens):
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
@@ -428,6 +438,7 @@ class MtpProposer(Proposer):
|
|||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||||
moe_comm_method=moe_comm_method,
|
moe_comm_method=moe_comm_method,
|
||||||
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
in_profile_run=self.runner.in_profile_run,
|
in_profile_run=self.runner.in_profile_run,
|
||||||
num_actual_tokens=num_tokens):
|
num_actual_tokens=num_tokens):
|
||||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||||
super().__init__(vllm_config, device)
|
super().__init__(vllm_config, device)
|
||||||
|
if self.speculative_config:
|
||||||
|
self.actual_seq_lengths_q = list(
|
||||||
|
range(self.decode_token_per_req, self.max_num_tokens + 1,
|
||||||
|
self.decode_token_per_req))
|
||||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||||
None, None, vllm_config, device)
|
None, None, vllm_config, device)
|
||||||
|
|
||||||
|
|||||||
@@ -306,17 +306,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.spec_attn_mask = None
|
self.spec_attn_mask = None
|
||||||
self.drafter: Optional[Union[NgramProposer, EagleProposer,
|
self.drafter: Optional[Union[NgramProposer, EagleProposer,
|
||||||
MtpProposer]] = None
|
MtpProposer]] = None
|
||||||
self.actual_seq_lengths_q = []
|
self.actual_seq_lengths_q: list[int] = []
|
||||||
self.decode_token_per_req = 1
|
self.decode_token_per_req = 1
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
spec_token_num = self.speculative_config.num_speculative_tokens
|
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||||
assert spec_token_num > 0
|
assert spec_token_num > 0
|
||||||
self.decode_token_per_req = 1 + spec_token_num
|
self.decode_token_per_req = 1 + spec_token_num
|
||||||
self.actual_seq_lengths_q = [
|
|
||||||
len for len in
|
|
||||||
range(self.decode_token_per_req, self.max_num_tokens +
|
|
||||||
1, self.decode_token_per_req)
|
|
||||||
]
|
|
||||||
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
||||||
2048,
|
2048,
|
||||||
dtype=torch.bool),
|
dtype=torch.bool),
|
||||||
|
|||||||
Reference in New Issue
Block a user