[Feat]Make full graph mode compalible with MTP (#3276)
### What this PR does / why we need it? Make the Full Graph mode can run with MTP. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -245,7 +245,8 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
event.record(update_stream)
|
||||
|
||||
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape):
|
||||
def update_mla_attn_params(update_stream, forward_context, runtime_shape,
|
||||
speculative_config):
|
||||
graph_params = get_graph_params()
|
||||
# FIXME: Behold! We are using a temporary hack here to update the args
|
||||
# for each layer's attention op in the graph.
|
||||
@@ -260,9 +261,19 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape):
|
||||
seq_lens_list, actual_seq_lengths, workspace, attn_output,
|
||||
softmax_lse) = param
|
||||
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||
len(seq_lens_list))
|
||||
|
||||
if speculative_config and speculative_config.method == "deepseek_mtp":
|
||||
actual_seq_lengths = forward_context.attn_metadata[
|
||||
key].decode.actual_seq_lengths_q
|
||||
spec_multiple = speculative_config.num_speculative_tokens + 1
|
||||
seq_lens_list = seq_lens_list + [0] * (
|
||||
runtime_shape // spec_multiple - len(seq_lens_list))
|
||||
actual_seq_lengths = [
|
||||
spec_multiple * (i + 1)
|
||||
for i in range(runtime_shape // spec_multiple)
|
||||
]
|
||||
else:
|
||||
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
|
||||
len(seq_lens_list))
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user