port deepseekv2 and mtp to main branch (#429)
### What this PR does / why we need it?

This PR ports all of the DeepSeek graph-mode code and the MTP (multi-token prediction) code from v0.7.3 to the main branch.

---------

Signed-off-by: SidaoY <1024863041@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: mengwei805 <mengwei25@huawei.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: q00832892 <qiaoyang19@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Co-authored-by: SidaoY <1024863041@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
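For reviewers skimming the diff: the core dispatch idea of the port is that graph mode is selected once at module construction time rather than per call. `CustomDeepseekV2MLAAttention` binds `self.forward` to either `forward_torchair` (TorchAir graph mode) or `forward_eager`, depending on `VLLM_ENABLE_GRAPH_MODE`. The snippet below is a minimal, self-contained sketch of that pattern only; it assumes the flag is read from the process environment (the real code imports the value from `vllm_ascend.utils`), and the two forward stubs stand in for the actual MLA implementations in the diff.

```python
import os


class MLAAttentionSketch:
    """Minimal sketch of the forward-dispatch pattern used in this PR.

    The ported CustomDeepseekV2MLAAttention binds self.forward at
    construction time to either the TorchAir (graph mode) path or the
    eager path, based on VLLM_ENABLE_GRAPH_MODE.
    """

    def __init__(self) -> None:
        # "1" selects the torchair graph-mode path; anything else falls back
        # to the eager path (mirrors the check on VLLM_ENABLE_GRAPH_MODE).
        if os.environ.get("VLLM_ENABLE_GRAPH_MODE", "0") == "1":
            self.forward = self.forward_torchair
        else:
            self.forward = self.forward_eager

    def forward_torchair(self, *args, **kwargs):
        # Graph-mode path: kv_cache and attn_metadata are passed explicitly.
        return "torchair path"

    def forward_eager(self, *args, **kwargs):
        # Eager path: kv projections are split and normalized before attention.
        return "eager path"


if __name__ == "__main__":
    print(MLAAttentionSketch().forward())
```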
@@ -19,36 +19,77 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Optional, Union
# <<<<<<< HEAD
# # Adapted from
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""
# from typing import Optional, Union

# import torch
# from torch import nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.linear import ReplicatedLinear
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.sampler import get_sampler
# from vllm.model_executor.layers.vocab_parallel_embedding import (
#     ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.models.deepseek_v2 import (  # noqa
#     DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
#     DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE)
# =======

import os
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed as dist
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                         get_current_vllm_config)
from vllm.distributed import (get_dp_group, get_pp_group,
                              get_tensor_model_parallel_world_size,
                              get_tp_group, tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_v2 import (  # noqa
    DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
    DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2MoE)
from vllm.model_executor.models.deepseek_v2 import \
    DeepseekV2ForCausalLM  # ruff: noqa: E501
from vllm.model_executor.models.deepseek_v2 import \
    yarn_get_mscale  # ruff: noqa: E501
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
                                                    DeepseekV2DecoderLayer,
                                                    DeepseekV2MLAAttention,
                                                    DeepseekV2MLP)
from vllm.model_executor.models.utils import (
    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
    maybe_prefix)
# >>>>>>> dcd5c73 (Feat: Graph mode for deepseek v2/v3.)
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE

class CustomDeepseekV2MoE(DeepseekV2MoE):

class CustomDeepseekV2MoE(nn.Module):

    top_k: int

    def __init__(
        self,
@@ -56,10 +97,15 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        nn.Module.__init__(self)
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}.")

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@@ -76,7 +122,7 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
        else:
            self.gate.e_score_correction_bias = None

        self.experts = FusedMoE(
        self.experts = AscendFusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
@@ -99,9 +145,248 @@ class CustomDeepseekV2MoE(DeepseekV2MoE):
            intermediate_size=intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            reduce_results=False,
            reduce_results=True,
            prefix=f"{prefix}.shared_experts",
        )
        CustomDeepseekV2MoE.top_k = config.num_experts_per_tok

        vllm_config = get_current_vllm_config()
        self.dp_size = get_dp_group().world_size
        batch_size = vllm_config.scheduler_config.max_num_seqs
        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1

        params_dtype = torch.get_default_dtype()
        self.final_hidden_states = torch.zeros(
            [batch_size, config.hidden_size], dtype=params_dtype, device="npu")
        self.tp_group = get_tp_group().device_group

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata is None:
            # for profile run
            return hidden_states
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)

        if (self.tp_size > 1 and self.enable_mc2
                and attn_metadata.num_prefills == 0):
            # hidden_states = dist._functional_collectives.reduce_scatter_tensor(
            #     hidden_states, "sum", scatter_dim=0, group=self.tp_group
            # )
            chunks = torch.chunk(hidden_states,
                                 get_tp_group().world_size,
                                 dim=0)
            hidden_states = chunks[get_tp_group().rank_in_group]

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        is_prefill = True if attn_metadata.num_prefills > 0 else False
        # is_prefill = attn_metadata.num_prefills > 0
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            is_prefill=is_prefill,
            top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor

        if self.tp_size > 1:
            if self.enable_mc2 and not is_prefill:
                dist.all_gather_into_tensor(self.final_hidden_states,
                                            final_hidden_states, self.tp_group)
                final_hidden_states = self.final_hidden_states
            else:
                final_hidden_states = tensor_model_parallel_all_reduce(
                    final_hidden_states)

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        return final_hidden_states.view(num_tokens, hidden_dim)


class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
        )

        self.prefix = prefix
        self.debug_layer_idx = int(self.prefix.split(".")[-2])
        if VLLM_ENABLE_GRAPH_MODE == "1":
            self.forward = self.forward_torchair
        else:
            self.forward = self.forward_eager  # type: ignore

    def forward_torchair(self,
                         positions: torch.Tensor,
                         hidden_states: torch.Tensor,
                         kv_cache: torch.Tensor = None,
                         attn_metadata=None):
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
                             kv_cache, attn_metadata)

    def forward_eager(self, positions: torch.Tensor,
                      hidden_states: torch.Tensor):
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        return self.mla_attn(hidden_states_or_q_c,
                             kv_c_normed,
                             k_pe,
                             output_shape=hidden_states.shape)

    # def forward(
    #     self,
    #     positions: torch.Tensor,
    #     hidden_states: torch.Tensor,
    #     # torchair should pass below two parameters
    #     kv_cache: torch.Tensor = None,
    #     attn_metadata: AttentionMetadata = None,
    # ) -> torch.Tensor:
    #     if self.q_lora_rank is not None:
    #         ckq = self.q_a_proj(hidden_states)[0]
    #         hidden_states_or_q_c = self.q_a_layernorm(ckq)
    #     else:
    #         hidden_states_or_q_c = hidden_states
    #     if VLLM_ENABLE_GRAPH_MODE == '1':
    #         return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
    #                              kv_cache, attn_metadata)
    #     else:
    #         kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
    #             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    #         kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
    #         return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, output_shape=hidden_states.shape)
    #                              kv_cache, attn_metadata)


class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -124,8 +409,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.layer_idx = layer_idx
        # TODO: enable mla in vllm-ascend
        if model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
            attn_cls = CustomDeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
@@ -180,8 +466,8 @@ class CustomDeepseekV2Model(nn.Module):
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        if get_pp_group().is_first_rank: