# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

import torch

from typing import Callable, Optional, List, Union, Tuple

from vllm_mlu import _mlu_ops as mlu_ops
from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from transformers import PretrainedConfig


def hunyuan_decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    kv_states: Optional[Tuple[torch.Tensor]] = None,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shared forward pass for a HunYuan decoder layer.

    Runs input layernorm -> self-attention -> post layernorm -> MLP, wiring
    the residual either from the normalized output or the raw input, and
    threads cross-layer kv states through the attention module.

    Args:
        positions: position ids, forwarded to ``self_attn`` under the
            keyword given by ``position_name``.
        hidden_states: input activations of this layer.
        input_layernorm / self_attn / post_layernorm / mlp: the layer's
            submodules, invoked in that order.
        kv_states: kv states shared from an earlier layer; ``None`` when
            this layer computes its own.
        apply_residual_connection_post_layernorm: when True, the residual
            is taken after the layernorm instead of before it.
        position_name: keyword name the attention module expects for the
            position tensor.
        input_norm_fuse_en: when True, ``input_layernorm`` additionally
            returns a smooth-quant scale fused into the norm kernel.

    Returns:
        Tuple of (layer output hidden states, the kv states produced by
        ``self_attn`` for possible reuse by later layers).
    """
    # Fused input norm also yields the smooth-quant scale; otherwise there
    # is no scale to forward. (No redundant pre-initialization needed —
    # both branches assign it, matching decoder_layer_forward_base.)
    if input_norm_fuse_en:
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        smooth_quant_scale = None

    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self Attention
    attention_output, ori_kv_states = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        kv_states=kv_states,
        smooth_quant_scale=smooth_quant_scale,
    )

    layernorm_output = post_layernorm(attention_output)
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # Fully Connected
    hidden_states = mlp(layernorm_output, residual)
    return hidden_states, ori_kv_states


def decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
    post_norm_fuse_en: bool = False,
) -> torch.Tensor:
    """Shared forward pass for a generic decoder layer.

    Applies input layernorm -> self-attention -> post layernorm -> MLP.
    The residual connection is taken either before or after each layernorm
    depending on ``apply_residual_connection_post_layernorm``. When a norm
    is "fused" it also returns a smooth-quant scale, which is forwarded to
    the consuming module (attention for the input norm, MLP for the post
    norm).

    Returns:
        The layer's output hidden states as produced by ``mlp``.
    """
    if input_norm_fuse_en:
        normed, quant_scale = input_layernorm(hidden_states)
    else:
        normed = input_layernorm(hidden_states)
        quant_scale = None

    residual = normed if apply_residual_connection_post_layernorm else hidden_states

    # Self attention; the position tensor travels under the module's
    # expected keyword name.
    attention_output = self_attn(
        hidden_states=normed,
        residual=residual,
        smooth_quant_scale=quant_scale,
        **{position_name: positions},
    )

    if post_norm_fuse_en:
        normed, quant_scale = post_layernorm(attention_output)
    else:
        normed = post_layernorm(attention_output)
        quant_scale = None

    residual = normed if apply_residual_connection_post_layernorm else attention_output

    # Fully connected; only the fused post-norm path hands the quant scale
    # on to the MLP.
    mlp_kwargs = {'smooth_quant_scale': quant_scale} if post_norm_fuse_en else {}
    return mlp(normed, residual, **mlp_kwargs)


def decoder_model_forward_base(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    layers: torch.nn.ModuleList,
    embed_input_ids: Callable,
    norm: Callable
) -> torch.Tensor:
    """Run a plain (non-pipeline-parallel) decoder model forward pass.

    Embeds ``input_ids``, threads the hidden states through every decoder
    layer in order, then applies the final norm.

    Args:
        input_ids: token ids handed to ``embed_input_ids``.
        positions: position ids forwarded to every layer.
        layers: the decoder layers, applied in order.
        embed_input_ids: embedding callable mapping ids to hidden states.
        norm: final normalization applied after the last layer.

    Returns:
        The normalized hidden states of the last layer.
    """
    hidden_states = embed_input_ids(input_ids)
    # Iterate layers directly instead of indexing with range(len(...)).
    for layer in layers:
        hidden_states = layer(
            positions,
            hidden_states,
        )
    return norm(hidden_states)


def hunyuan_decoder_model_forward_base_pp(
    config: PretrainedConfig,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel forward pass for a HunYuan decoder model stack.

    The first pipeline rank embeds the inputs (preferring precomputed
    ``inputs_embeds``); later ranks resume from the hidden states handed
    over in ``intermediate_tensors``. Key/value states are forwarded
    between consecutive layers according to the model's
    ``cla_share_factor`` (cross-layer attention sharing).

    Returns:
        ``IntermediateTensors`` carrying the hidden states on non-final
        ranks; otherwise the normalized hidden states.
    """
    if get_pp_group().is_first_rank:
        # First rank owns the embedding step.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = embed_input_ids(input_ids)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    cla_factor = getattr(config, "cla_share_factor", 1)
    prev_kv_states = None
    for idx in range(start_layer, end_layer):
        hidden_states, kv_states = layers[idx](
            positions,
            hidden_states,
            prev_kv_states,
        )
        # A layer at a multiple of cla_factor publishes its kv states for
        # the next layer; other positions clear the share slot.
        prev_kv_states = (
            kv_states if (idx - start_layer) % cla_factor == 0 else None)

    if not get_pp_group().is_last_rank:
        # Hand the activations to the next pipeline stage.
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })

    return norm(hidden_states)


def decoder_model_forward_base_pp(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel forward pass for a generic decoder model stack.

    The first pipeline rank embeds the inputs (preferring precomputed
    ``inputs_embeds``); later ranks resume from the hidden states handed
    over in ``intermediate_tensors``. Only the layers in
    ``[start_layer, end_layer)`` assigned to this rank are executed.

    Returns:
        ``IntermediateTensors`` carrying the hidden states on non-final
        ranks; otherwise the normalized hidden states.
    """
    if get_pp_group().is_first_rank:
        # First rank owns the embedding step.
        hidden_states = (inputs_embeds if inputs_embeds is not None
                         else embed_input_ids(input_ids))
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    for idx in range(start_layer, end_layer):
        hidden_states = layers[idx](
            positions,
            hidden_states,
        )

    if not get_pp_group().is_last_rank:
        # Hand the activations to the next pipeline stage.
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })

    return norm(hidden_states)


def is_smoothquant(quant_config: QuantizationConfig) -> bool:
    """Return True when ``quant_config`` describes SmoothQuant quantization."""
    if quant_config is None:
        return False
    return quant_config.get_name() == "SmoothQuant"


def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
    """Return True for SmoothQuant configs using per-token input quantization."""
    if not is_smoothquant(quant_config):
        return False
    return quant_config.input_quant_method == "per_token"


def compute_in_loop(func: Callable,
                    input: torch.Tensor,
                    chunk_size: int,
                    feature_size: Optional[int] = None,
                    **kwargs) -> torch.Tensor:
    """Apply ``func`` to ``input`` chunk-by-chunk along dimension 0.

    Divides ``input`` into chunks of at most ``chunk_size`` rows in the
    leading dimension and computes the chunks in a loop, instead of in one
    batch at once, bounding peak memory usage.

    Args:
        func: callable applied to each chunk; must preserve the chunk's
            leading dimension.
        input: 2-D tensor whose rows are processed in chunks.
        chunk_size: maximum number of rows handed to ``func`` per call.
        feature_size: size of the output feature dimension. Provide it when
            the output's feature dimension would differ from the input's
            feature dimension.
        **kwargs: forwarded verbatim to ``func``.

    Returns:
        Tensor of shape ``(input.shape[0], feature_size)`` (defaulting the
        feature dimension to ``input.shape[1]``).
    """
    total = input.shape[0]
    # Directly compute if everything fits into a single chunk.
    if chunk_size >= total:
        return func(input, **kwargs)

    # Explicit None check: a caller-supplied feature_size of 0 must not be
    # silently replaced by the input's feature dimension (the previous
    # `feature_size or input.shape[1]` idiom discarded falsy values).
    if feature_size is None:
        feature_size = input.shape[1]
    output = input.new_empty(total, feature_size)
    num_chunks = (total + chunk_size - 1) // chunk_size

    for i in range(num_chunks):
        start = i * chunk_size
        # The final chunk may be short.
        end = min(start + chunk_size, total)
        output[start:end] = func(input[start:end], **kwargs)

    return output