port deepseekv2 and mtp to main branch (#429)

### What this PR does / why we need it?
This PR ports the DeepSeek graph-mode code and the MTP (multi-token prediction) code from v0.7.3 to the main branch.
---------

Signed-off-by: SidaoY <1024863041@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: mengwei805 <mengwei25@huawei.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: q00832892 <qiaoyang19@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Co-authored-by: SidaoY <1024863041@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
Authored by Pleaplusone on 2025-04-19 17:38:18 +08:00, committed via GitHub
parent 086423dc35
commit 1a1f9a6d89
33 changed files with 3361 additions and 315 deletions


@@ -18,18 +18,21 @@
#
import dataclasses
import itertools
import weakref
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
Type, TypeVar, Union)
import numpy as np
import torch
import torch.nn as nn
import torch_npu
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
@@ -53,7 +56,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists,
is_pin_memory_available)
is_pin_memory_available, supports_dynamo)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@@ -72,6 +75,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
TModelInputForNPU = TypeVar('TModelInputForNPU', bound="ModelInputForNPU")
ENCODER_NUM = 0
@dataclass(frozen=True)
@@ -526,6 +530,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
seq_lens = []
max_decode_seq_len = 0
is_prompt = self.inter_data_list[0].is_prompt
for inter_data in self.inter_data_list:
seq_lens.extend(inter_data.seq_lens)
if not inter_data.is_prompt:
@@ -540,7 +545,26 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
for data in self.inter_data_list
}
input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens),
# Add graph_pad_size here
if self.runner.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
graph_pad_size = self.runner.scheduler_config.max_num_seqs - len(
seq_lens)
else:
graph_pad_size = -1
#print(f"before tensor input_tokens: {input_tokens}")
#print(f"before tensor input_positions: {input_positions}")
#print(f"before list seq_lens: {seq_lens}")
input_tokens = flatten_2d_lists(input_tokens)
input_positions = flatten_2d_lists(input_positions)
if graph_pad_size != -1 and not is_prompt:
input_tokens.extend(itertools.repeat(0, graph_pad_size))
input_positions.extend( # type: ignore
itertools.repeat(0, graph_pad_size))
seq_lens.extend(itertools.repeat(1, graph_pad_size))
query_lens.extend(itertools.repeat(1, graph_pad_size))
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.runner.device)
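For reference, the padding logic added above can be exercised in isolation. The following is a minimal sketch (the helper name `pad_decode_batch` and the `max_num_seqs` value are made up for illustration, not part of this change) of how a decode batch is padded to a fixed size so the captured graph always sees the same shapes: dummy tokens and positions are 0, dummy sequence and query lengths are 1.

```python
import itertools

def pad_decode_batch(input_tokens, input_positions, seq_lens, query_lens,
                     max_num_seqs):
    # Pad the flattened decode batch up to max_num_seqs with dummy entries,
    # mirroring the values used by the builder above.
    graph_pad_size = max_num_seqs - len(seq_lens)
    if graph_pad_size > 0:
        input_tokens.extend(itertools.repeat(0, graph_pad_size))
        input_positions.extend(itertools.repeat(0, graph_pad_size))
        seq_lens.extend(itertools.repeat(1, graph_pad_size))
        query_lens.extend(itertools.repeat(1, graph_pad_size))
    return input_tokens, input_positions, seq_lens, query_lens

# Example: a 3-sequence decode batch padded to max_num_seqs = 8.
tokens, positions, seqs, queries = pad_decode_batch(
    [11, 22, 33], [5, 7, 9], [6, 8, 10], [1, 1, 1], max_num_seqs=8)
assert len(tokens) == len(positions) == len(seqs) == len(queries) == 8
```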
if mrope_input_positions is not None:
@@ -548,13 +572,16 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
dtype=torch.long,
device=self.runner.device)
else:
input_positions_tensor = torch.tensor(
flatten_2d_lists(input_positions),
dtype=torch.long,
device=self.runner.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.runner.device)
#print(f"after tensor input_tokens_tensor: {input_tokens_tensor}")
#print(f"after tensor input_positions_tensor: {input_positions_tensor}")
#print(f"after list seq_lens: {seq_lens}")
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens)
attn_metadata = self.attn_metadata_builder.build(
seq_lens, query_lens, graph_pad_size)
# LoRA data.
lora_requests = set()
@@ -582,6 +609,13 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
]
multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
if self.runner.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
torch._dynamo.mark_static(input_tokens_tensor)
torch._dynamo.mark_static(input_positions_tensor)
torch._dynamo.mark_static(attn_metadata.block_tables)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
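A note on the `mark_static` calls above: `torch._dynamo.mark_static` tells Dynamo to treat a tensor's dimensions as static instead of guarding on (and potentially recompiling for) new shapes, which matches the fixed-size batches produced by the padding. A minimal, self-contained sketch with a toy module (the module and shapes are invented for illustration and are not runner code):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2 + 1

compiled = torch.compile(Toy(), dynamic=True)

x = torch.randn(8, 16)
# Mark the input as static so Dynamo specializes the compiled graph on this
# exact shape rather than treating the dimensions as dynamic.
torch._dynamo.mark_static(x)
out = compiled(x)
assert out.shape == (8, 16)
```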
return self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
@@ -841,6 +875,12 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
self.in_profile_run = False
self.graph_block_tables = np.zeros(
(self.vllm_config.scheduler_config.max_num_seqs,
(model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)
# Attention-free but stateful models like Mamba need a placeholder attn
# backend, as the attention metadata is needed to manage internal state.
# However we must bypass attention selection altogether for some models
@@ -930,6 +970,26 @@ class NPUModelRunnerBase(ModelRunnerBase[TModelInputForNPU]):
)
self.model = self.lora_manager.create_lora_manager(self.model)
# Adapt torch.compile to use the NPU backend
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
import torchair # type: ignore
from torchair import patch_for_hcom # type: ignore
# Capture collective-communication ops into the graph.
patch_for_hcom()
# Configure the NPU graph compiler; if no custom config is needed, the
# default can be used by passing npu_backend="npu".
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
torch.npu.set_compile_mode(jit_compile=False)
self.compile_model = torchair.inference.cache_compile(
self.model.forward,
dynamic=True,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
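To make the wiring above easier to follow outside the runner, here is a hedged, standalone sketch of the same torchair setup applied to a toy module. The module and input shapes are invented for illustration; the torchair calls mirror the ones introduced in this diff and assume an Ascend environment with torch_npu and torchair installed.

```python
import torch
import torch.nn as nn
import torch_npu  # noqa: F401  (registers the NPU device)
import torchair
from torchair import patch_for_hcom

class ToyDecoder(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x @ x.transpose(-1, -2))

model = ToyDecoder().npu()

# Capture collective-communication ops into the graph as well.
patch_for_hcom()

# Compiler options matching the ones enabled in this change.
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True

# Disable per-op JIT compilation; the graph compiler takes over.
torch.npu.set_compile_mode(jit_compile=False)

compiled_forward = torchair.inference.cache_compile(
    model.forward,
    dynamic=True,
    fullgraph=False,  # the runner takes this from VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
    config=config,
    ge_cache=False)

out = compiled_forward(torch.randn(4, 8).npu())
```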
def save_sharded_state(
self,
path: str,
@@ -1219,10 +1279,43 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
self.attn_state.begin_forward(model_input)
assert model_input.attn_metadata is not None
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
torch._dynamo.mark_static(model_input.input_tokens)
torch._dynamo.mark_static(model_input.input_positions)
torch._dynamo.mark_static(model_input.attn_metadata.block_tables)
torch._dynamo.mark_static(model_input.attn_metadata.slot_mapping)
torch._dynamo.mark_static(
model_input.attn_metadata.query_start_loc)
torch._dynamo.mark_static(model_input.attn_metadata.seq_start_loc)
for kv in kv_caches:
if isinstance(kv, tuple):
torch._dynamo.mark_static(kv[0])
torch._dynamo.mark_static(kv[1])
# TODO(andoorve): We can remove this once all
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
model_executable = self.model
prefill_meta = model_input.attn_metadata.prefill_metadata
previous_hidden_states = kwargs.get("previous_hidden_states")
if prefill_meta is None and self.vllm_config.compilation_config.level > 0:
model_executable = self.compile_model
# Note: graph_batch_size here differs from the value used on GPU.
graph_batch_size = model_input.input_tokens.shape[0]  # type: ignore
# Note: previous_hidden_states may be None here, unlike on GPU.
if previous_hidden_states is not None:
previous_hidden_states = torch.cat([
previous_hidden_states,
torch.empty([
graph_batch_size - previous_hidden_states.shape[0],
*previous_hidden_states.shape[1:]
],
dtype=previous_hidden_states.dtype,
device=previous_hidden_states.device)
])
else:
model_executable = self.model
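The `previous_hidden_states` handling above can be visualized with concrete shapes (the sizes below are invented for the example): the tensor carried over from the previous step only covers the real sequences, so it is padded with uninitialized rows up to the graph batch size; the extra rows belong to the dummy padded tokens and their outputs are discarded after the forward pass.

```python
import torch

graph_batch_size = 8             # padded token count seen by the graph
prev = torch.randn(3, 4096)      # hidden states for the 3 real sequences

# Pad with uninitialized rows up to the graph batch size, mirroring the
# torch.cat/torch.empty pattern in the runner code above.
pad = torch.empty(graph_batch_size - prev.shape[0], *prev.shape[1:],
                  dtype=prev.dtype, device=prev.device)
prev = torch.cat([prev, pad])
assert prev.shape == (8, 4096)

# After the forward pass, only the rows for real sequences are kept,
# as in `hidden_or_intermediate_states[:len(indices)]` later in this diff.
real = prev[:3]
assert real.shape == (3, 4096)
```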
# Receive KV cache in distributed KV cache transfer setting
# In disagg prefill setting, it will also recv hidden states and bypass
@@ -1248,8 +1341,11 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_inner_state else {}
previous_hidden_states = kwargs.get("previous_hidden_states")
model_kwargs = {}
if self.vllm_config.compilation_config.level ==\
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
model_kwargs = {"inputs_embeds": None}
else:
model_kwargs = {}
if previous_hidden_states is not None:
model_kwargs["previous_hidden_states"] = previous_hidden_states
@@ -1273,44 +1369,30 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
**seqlen_agnostic_kwargs,
**model_kwargs)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.record()
# Sending KV cache in distributed KV cache transfer setting
# NOTE: the send operation is non-blocking
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
kv_caches,
hidden_or_intermediate_states,
)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors["model_forward_time"] = (
torch.tensor(model_forward_time + orig_model_forward_time))
return hidden_or_intermediate_states
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
if (self.is_driver_worker
and hidden_or_intermediate_states is not None
and isinstance(hidden_or_intermediate_states,
IntermediateTensors)
and self.observability_config is not None and
self.observability_config.collect_model_forward_time):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
hidden_or_intermediate_states.tensors[
"model_forward_time"] = (
torch.tensor(model_forward_time +
orig_model_forward_time))
return hidden_or_intermediate_states
# TODO: remove the synchronize here
torch.npu.synchronize()
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)
if not self.is_driver_worker:
return []
@@ -1348,6 +1430,9 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
hidden_states = hidden_or_intermediate_states.index_select(
0, indices)
output.prefill_hidden_states = hidden_or_intermediate_states
elif self.vllm_config.compilation_config.level == \
CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
hidden_states = hidden_or_intermediate_states[:len(indices)]
else:
hidden_states = hidden_or_intermediate_states