[Model] Support DeepSeek-V4
This commit is contained in:
245
vllm_mlu/model_executor/models/layer_utils.py
Executable file
245
vllm_mlu/model_executor/models/layer_utils.py
Executable file
@@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
from typing import Callable, Optional, List, Union, Tuple
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
def hunyuan_decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    kv_states: Optional[Tuple[torch.Tensor]] = None,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shared decoder-layer forward for HunYuan-style models with CLA.

    Runs input layernorm -> self attention -> post layernorm -> MLP with
    residual connections, threading cross-layer-attention (CLA) kv states
    through the attention call.

    Args:
        positions: position ids, forwarded to ``self_attn`` under the
            keyword given by ``position_name``.
        hidden_states: layer input hidden states.
        input_layernorm: pre-attention norm. When ``input_norm_fuse_en``
            is True it must return ``(normed, smooth_quant_scale)``.
        self_attn: attention callable; receives positions, hidden_states,
            residual, kv_states and smooth_quant_scale as keywords and
            returns ``(attention_output, kv_states)``.
        post_layernorm: pre-MLP norm.
        mlp: MLP callable taking ``(normed_output, residual)``.
        kv_states: kv states shared from an earlier layer (CLA), if any.
        apply_residual_connection_post_layernorm: take the residual from
            the layernorm output instead of the layernorm input.
        position_name: keyword name under which ``positions`` is passed
            to ``self_attn``.
        input_norm_fuse_en: whether ``input_layernorm`` also returns a
            smooth-quant scale (fused quantized norm).

    Returns:
        Tuple of (layer output hidden states, kv states produced by this
        layer's attention, for CLA sharing with subsequent layers).
    """
    if input_norm_fuse_en:
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        # No fused quantized norm: attention receives no smooth-quant scale.
        # (Fix: the scale was previously also dead-stored before the branch.)
        smooth_quant_scale = None
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self Attention
    attention_output, ori_kv_states = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        kv_states=kv_states,
        smooth_quant_scale=smooth_quant_scale,
    )

    layernorm_output = post_layernorm(attention_output)
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # Fully Connected
    hidden_states = mlp(layernorm_output, residual)
    return hidden_states, ori_kv_states
|
||||
|
||||
|
||||
def decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
    post_norm_fuse_en: bool = False,
) -> torch.Tensor:
    """Generic decoder-layer forward: norm -> attention -> norm -> MLP.

    Either norm may be a fused quantized norm that additionally emits a
    smooth-quant scale; the input-norm scale goes to ``self_attn`` and the
    post-norm scale (only when ``post_norm_fuse_en``) goes to ``mlp``.

    Args:
        positions: position ids, forwarded to ``self_attn`` under the
            keyword given by ``position_name``.
        hidden_states: layer input hidden states.
        input_layernorm: pre-attention norm; returns ``(normed, scale)``
            when ``input_norm_fuse_en`` is True.
        self_attn: attention callable returning the attention output.
        post_layernorm: pre-MLP norm; returns ``(normed, scale)`` when
            ``post_norm_fuse_en`` is True.
        mlp: MLP callable taking ``(normed, residual)`` plus an optional
            ``smooth_quant_scale`` keyword.
        apply_residual_connection_post_layernorm: take residuals from the
            norm outputs instead of the norm inputs.
        position_name: keyword name for ``positions`` in the attention call.
        input_norm_fuse_en: input norm is fused with quantization scaling.
        post_norm_fuse_en: post norm is fused with quantization scaling.

    Returns:
        The layer output hidden states.
    """
    # Pre-attention norm; the fused variant also emits a smooth-quant scale.
    if input_norm_fuse_en:
        normed, attn_scale = input_layernorm(hidden_states)
    else:
        normed, attn_scale = input_layernorm(hidden_states), None

    residual = normed if apply_residual_connection_post_layernorm \
        else hidden_states

    # Self attention; positions travel under a model-specific keyword.
    attention_output = self_attn(
        **{position_name: positions},
        hidden_states=normed,
        residual=residual,
        smooth_quant_scale=attn_scale,
    )

    # Pre-MLP norm, again optionally fused with quantization scaling.
    if post_norm_fuse_en:
        normed, mlp_scale = post_layernorm(attention_output)
    else:
        normed, mlp_scale = post_layernorm(attention_output), None

    residual = normed if apply_residual_connection_post_layernorm \
        else attention_output

    # The MLP only sees the scale when the fused post-norm produced one.
    mlp_kwargs = {'smooth_quant_scale': mlp_scale} if post_norm_fuse_en else {}
    return mlp(normed, residual, **mlp_kwargs)
|
||||
|
||||
|
||||
def decoder_model_forward_base(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    layers: torch.nn.ModuleList,
    embed_input_ids: Callable,
    norm: Callable
) -> torch.Tensor:
    """Embed the tokens, run every decoder layer in order, apply final norm.

    Args:
        input_ids: token ids handed to ``embed_input_ids``.
        positions: position ids passed to every layer.
        layers: decoder layers, each called as ``layer(positions, hidden)``.
        embed_input_ids: embedding callable.
        norm: final normalization callable.

    Returns:
        The normalized hidden states after the last layer.
    """
    hidden_states = embed_input_ids(input_ids)
    for layer in layers:
        hidden_states = layer(positions, hidden_states)
    return norm(hidden_states)
|
||||
|
||||
|
||||
def hunyuan_decoder_model_forward_base_pp(
    config: PretrainedConfig,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel model forward for HunYuan models with CLA.

    The first pipeline rank embeds the tokens (or takes ``inputs_embeds``);
    later ranks pick up hidden states from ``intermediate_tensors``. Only
    the last rank applies the final norm; other ranks ship their hidden
    states downstream.

    Args:
        config: model config; ``cla_share_factor`` controls how many
            consecutive layers share one set of kv states.
        input_ids: token ids (first rank only, when no ``inputs_embeds``).
        positions: position ids passed to every layer.
        intermediate_tensors: hidden states from the previous rank
            (required on non-first ranks).
        layers: full layer list; this rank runs [start_layer, end_layer).
        start_layer: first layer index owned by this rank.
        end_layer: one past the last layer index owned by this rank.
        embed_input_ids: embedding callable.
        norm: final normalization callable (last rank only).
        inputs_embeds: precomputed embeddings, overriding ``input_ids``.

    Returns:
        Final hidden states on the last rank, otherwise an
        ``IntermediateTensors`` carrying ``hidden_states``.
    """
    pp_group = get_pp_group()
    if pp_group.is_first_rank:
        if inputs_embeds is None:
            hidden_states = embed_input_ids(input_ids)
        else:
            hidden_states = inputs_embeds
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    # Cross-layer attention: every cla_factor-th layer (relative to this
    # rank's first layer) produces fresh kv states; the layers in between
    # consume them.
    cla_factor = getattr(config, "cla_share_factor", 1)
    prev_kv_states = None
    for idx in range(start_layer, end_layer):
        hidden_states, kv_states = layers[idx](
            positions,
            hidden_states,
            prev_kv_states,
        )
        produces_kv = (idx - start_layer) % cla_factor == 0
        prev_kv_states = kv_states if produces_kv else None

    if pp_group.is_last_rank:
        return norm(hidden_states)
    return IntermediateTensors({
        "hidden_states": hidden_states,
    })
|
||||
|
||||
|
||||
def decoder_model_forward_base_pp(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel decoder model forward.

    The first pipeline rank embeds the tokens (or takes ``inputs_embeds``);
    later ranks read hidden states from ``intermediate_tensors``. Only the
    last rank applies the final norm; other ranks forward an
    ``IntermediateTensors`` to the next rank.

    Args:
        input_ids: token ids (first rank only, when no ``inputs_embeds``).
        positions: position ids passed to every layer.
        intermediate_tensors: hidden states from the previous rank
            (required on non-first ranks).
        layers: full layer list; this rank runs [start_layer, end_layer).
        start_layer: first layer index owned by this rank.
        end_layer: one past the last layer index owned by this rank.
        embed_input_ids: embedding callable.
        norm: final normalization callable (last rank only).
        inputs_embeds: precomputed embeddings, overriding ``input_ids``.

    Returns:
        Final hidden states on the last rank, otherwise an
        ``IntermediateTensors`` carrying ``hidden_states``.
    """
    pp_group = get_pp_group()
    if pp_group.is_first_rank:
        if inputs_embeds is None:
            hidden_states = embed_input_ids(input_ids)
        else:
            hidden_states = inputs_embeds
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    for idx in range(start_layer, end_layer):
        hidden_states = layers[idx](
            positions,
            hidden_states,
        )

    if pp_group.is_last_rank:
        return norm(hidden_states)
    return IntermediateTensors({
        "hidden_states": hidden_states,
    })
|
||||
|
||||
|
||||
def is_smoothquant(quant_config: QuantizationConfig) -> bool:
    """Return True if ``quant_config`` describes SmoothQuant quantization."""
    if quant_config is None:
        return False
    return quant_config.get_name() == "SmoothQuant"
|
||||
|
||||
|
||||
def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
    """Return True for SmoothQuant configs using per-token input quant."""
    if not is_smoothquant(quant_config):
        return False
    return quant_config.input_quant_method == "per_token"
|
||||
|
||||
def compute_in_loop(func: Callable,
                    input: torch.Tensor,
                    chunk_size: int,
                    feature_size: Optional[int] = None,
                    **kwargs):
    """Apply ``func`` to ``input`` chunk-by-chunk along dimension 0.

    Splits ``input`` into slices of at most ``chunk_size`` rows and runs
    ``func`` on each slice sequentially, writing the results into a
    preallocated output tensor, instead of computing the whole batch at
    once.

    Args:
        func: callable applied to each chunk; ``kwargs`` are forwarded.
        input: 2-D tensor, chunked along its leading dimension.
        chunk_size: maximum number of rows per chunk.
        feature_size: size of the output feature dimension. Provide it
            when the output's feature dimension differs from the input's.

    Returns:
        The concatenated results of ``func`` over all chunks.
    """
    total = input.shape[0]
    # Everything fits in one chunk: call straight through, no buffer needed.
    if chunk_size >= total:
        return func(input, **kwargs)

    out_features = feature_size or input.shape[1]
    output = input.new_empty(total, out_features)

    for start in range(0, total, chunk_size):
        stop = min(start + chunk_size, total)
        output[start:stop] = func(input[start:stop], **kwargs)

    return output
|
||||
Reference in New Issue
Block a user