[MTP][V1] Adapt mtp with graph mode in v1. (#1023)
Adapts deepseek mtp with torch air graph mode in v1. --------- Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -100,6 +100,7 @@ class AscendAttentionState(Enum):
|
||||
PrefillCacheHit = 1
|
||||
DecodeOnly = 2
|
||||
ChunkedPrefill = 3
|
||||
SpecDecoding = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -8,6 +8,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
|
||||
@@ -86,6 +87,7 @@ class AscendMLADecodeMetadata:
|
||||
seq_lens: torch.Tensor
|
||||
max_seq_lens: int
|
||||
seq_lens_list: list[int]
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -169,6 +171,8 @@ class AscendMLAMetadataBuilder:
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
@@ -185,16 +189,24 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
# for now treat 1 scheduled token as "decode" even if its not,
|
||||
# we should update this to something like < 8 in the future but
|
||||
# currently the TritonMLA._forward_decode only supports
|
||||
# num_tokens = 1
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
num_spec_tokens = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
||||
# For torch air graph mode we treat spec decoding as decode.
|
||||
if self.torchair_graph_enabled:
|
||||
if num_tokens - num_spec_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
# For eager mode we treat spec decoding as chunked prefill.
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
@@ -284,7 +296,8 @@ class AscendMLAMetadataBuilder:
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=1)
|
||||
max_seq_lens=1,
|
||||
attn_mask=self.runner.spec_attn_mask)
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_input_tokens=num_actual_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -332,7 +345,7 @@ class AscendMLAMetadataBuilder:
|
||||
seq_lens = seq_lens_cpu
|
||||
max_query_len = query_lens.max().item()
|
||||
max_seq_lens = seq_lens.max().item()
|
||||
query_start_loc = None
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
|
||||
prefill_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
@@ -397,7 +410,8 @@ class AscendMLAMetadataBuilder:
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
max_seq_lens=max_seq_lens)
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=self.runner.spec_attn_mask)
|
||||
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
@@ -461,6 +475,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
# Adapt torch air graph mode with spec decoding.
|
||||
speculative_config = get_current_vllm_config().speculative_config
|
||||
if speculative_config is not None:
|
||||
self.spec_token_num = speculative_config.num_speculative_tokens
|
||||
assert self.spec_token_num > 0
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
@@ -550,7 +569,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
num_tokens = query.size(0)
|
||||
attn_output = None
|
||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
||||
if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
@@ -597,7 +619,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
|
||||
"Unexpected path reached, AscendMLAImpl should only have PrefillNoCache, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
||||
)
|
||||
attn_output = attn_output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
@@ -670,9 +692,28 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
if self.running_in_graph:
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, seq_len, dim]
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
# TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
|
||||
if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||
assert num_tokens % self.spec_token_num == 0
|
||||
q_nope = (q_nope.view(
|
||||
num_tokens // (self.spec_token_num + 1),
|
||||
self.spec_token_num + 1,
|
||||
self.num_heads,
|
||||
-1,
|
||||
).transpose(1, 2).contiguous())
|
||||
q_pe = (q_pe.view(
|
||||
num_tokens // (self.spec_token_num + 1),
|
||||
self.spec_token_num + 1,
|
||||
self.num_heads,
|
||||
-1,
|
||||
).transpose(1, 2).contiguous())
|
||||
sparse_mode = 3
|
||||
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
else:
|
||||
q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
|
||||
q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
|
||||
sparse_mode = 0
|
||||
spec_attn_mask = None
|
||||
# shape of knope/k_pe for npu graph mode should be:
|
||||
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||
block_size = kv_c_and_k_pe_cache[0].shape[1]
|
||||
@@ -690,7 +731,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
input_layout="BNSD",
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
atten_mask=spec_attn_mask,
|
||||
sparse_mode=sparse_mode,
|
||||
scale=self.scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
@@ -732,7 +774,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output
|
||||
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
self.running_in_graph = self.torchair_graph_enabled and attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
if k_pe is None and not self.running_in_graph:
|
||||
kv_c, k_pe = self.kv_a_proj_with_mqa(
|
||||
|
||||
@@ -203,8 +203,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Set up speculative decoding.
|
||||
self.use_spec_decode = False
|
||||
self.spec_attn_mask = None
|
||||
if self.speculative_config:
|
||||
self.use_spec_decode = True
|
||||
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
||||
2048,
|
||||
dtype=torch.bool),
|
||||
diagonal=1).to("npu")
|
||||
if get_pp_group().is_last_rank:
|
||||
if self.speculative_config.method == "ngram":
|
||||
self.drafter = NgramProposer(self.vllm_config)
|
||||
@@ -779,10 +784,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Get the number of scheduled tokens for each request.
|
||||
# TODO: The Python loop can be slow. Optimize.
|
||||
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
|
||||
num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
|
||||
max_num_scheduled_tokens = 0
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_scheduled_tokens[i] = num_tokens
|
||||
num_valid_tokens[i] = num_tokens - \
|
||||
len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
|
||||
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
|
||||
num_tokens)
|
||||
|
||||
@@ -838,11 +846,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
out=self.slot_mapping_np[:total_num_scheduled_tokens])
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
|
||||
attn_state = AscendAttentionState.PrefillNoCache
|
||||
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
|
||||
elif np.all(num_scheduled_tokens == 1):
|
||||
attn_state = AscendAttentionState.DecodeOnly
|
||||
# Speculative decoding.
|
||||
elif np.all(num_valid_tokens == 1):
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# splitfuse
|
||||
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
@@ -873,7 +886,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
seq_lens = self.seq_lens[:num_reqs]
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
||||
with_prefill = attn_state != AscendAttentionState.DecodeOnly
|
||||
with_prefill = attn_state not in [
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
|
||||
if self.dp_size > 1:
|
||||
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
|
||||
@@ -883,14 +898,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Add graph_pad_size here
|
||||
if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
|
||||
and not with_prefill):
|
||||
batch_size = len(seq_lens)
|
||||
if self.dp_size > 1:
|
||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||
max_num_tokens)
|
||||
else:
|
||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||
batch_size)
|
||||
graph_pad_size = padded_batch_size - batch_size
|
||||
total_num_scheduled_tokens)
|
||||
graph_pad_size = padded_batch_size - total_num_scheduled_tokens
|
||||
|
||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
|
||||
@@ -4,7 +4,8 @@ from vllm.config import (VllmConfig, get_layers_from_vllm_config,
|
||||
set_current_vllm_config)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
||||
@@ -199,6 +200,8 @@ class MtpProposer:
|
||||
loader.get_all_weights(
|
||||
self.vllm_config.speculative_config.draft_model_config,
|
||||
self.model))
|
||||
process_weights_after_loading(self.model, draft_model_config,
|
||||
target_device)
|
||||
|
||||
|
||||
# TODO Using torch instead of triton may result in poor performance
|
||||
|
||||
Reference in New Issue
Block a user