# --- page-export metadata (not code); kept as a comment so the file parses ---
# Files — 2026-04-24 09:58:03 +08:00
# 245 lines · 7.5 KiB · Python · Executable File
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from typing import Callable, Optional, List, Union, Tuple
from vllm_mlu import _mlu_ops as mlu_ops
from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from transformers import PretrainedConfig
def hunyuan_decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    kv_states: Optional[Tuple[torch.Tensor]] = None,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shared decoder-layer forward for HunYuan-style models with
    cross-layer attention (CLA).

    Runs input layernorm -> self-attention -> post layernorm -> MLP with
    residual connections, threading ``kv_states`` through the attention
    call and returning the KV states attention produced so a following
    layer can reuse them.

    Args:
        positions: position ids, passed to attention under ``position_name``.
        hidden_states: layer input.
        input_layernorm: pre-attention norm; when ``input_norm_fuse_en``
            it must return ``(normed, smooth_quant_scale)``.
        self_attn: attention callable receiving hidden_states, residual,
            kv_states and smooth_quant_scale; returns ``(output, kv_states)``.
        post_layernorm: pre-MLP norm.
        mlp: MLP callable taking ``(normed, residual)``.
        kv_states: KV states shared from an earlier CLA layer, or None.
        apply_residual_connection_post_layernorm: take the residual from
            the layernorm output instead of the layernorm input.
        position_name: keyword under which positions are passed to attention.
        input_norm_fuse_en: input layernorm is fused with smooth-quant.

    Returns:
        Tuple of (layer output, KV states produced by self-attention).
    """
    # Fix: dropped the redundant pre-initialization of smooth_quant_scale —
    # both branches below assign it (matches decoder_layer_forward_base).
    if input_norm_fuse_en:
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        smooth_quant_scale = None
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states
    # Self Attention
    attention_output, ori_kv_states = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        kv_states=kv_states,
        smooth_quant_scale=smooth_quant_scale,
    )
    layernorm_output = post_layernorm(attention_output)
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output
    # Fully Connected
    hidden_states = mlp(layernorm_output, residual)
    return hidden_states, ori_kv_states
def decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
    post_norm_fuse_en: bool = False,
) -> torch.Tensor:
    """Shared decoder-layer forward: layernorm -> self-attention ->
    layernorm -> MLP, with a residual connection around each half.

    When ``input_norm_fuse_en`` / ``post_norm_fuse_en`` is set, the
    corresponding layernorm is expected to also return a smooth-quant
    scale, which is forwarded to the attention / MLP call respectively.
    """
    pre_norm_scale = None
    if input_norm_fuse_en:
        normed, pre_norm_scale = input_layernorm(hidden_states)
    else:
        normed = input_layernorm(hidden_states)
    residual = normed if apply_residual_connection_post_layernorm else hidden_states

    # Self attention; positions are passed under the configurable keyword.
    attn_out = self_attn(
        **{position_name: positions},
        hidden_states=normed,
        residual=residual,
        smooth_quant_scale=pre_norm_scale,
    )

    post_norm_scale = None
    if post_norm_fuse_en:
        normed, post_norm_scale = post_layernorm(attn_out)
    else:
        normed = post_layernorm(attn_out)
    residual = normed if apply_residual_connection_post_layernorm else attn_out

    # Only the fused path forwards a quant scale to the MLP.
    mlp_kwargs = {'smooth_quant_scale': post_norm_scale} if post_norm_fuse_en else {}
    return mlp(normed, residual, **mlp_kwargs)
def decoder_model_forward_base(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    layers: torch.nn.ModuleList,
    embed_input_ids: Callable,
    norm: Callable
) -> torch.Tensor:
    """Shared decoder-model forward: embed tokens, run every decoder
    layer in order, then apply the final norm.

    Args:
        input_ids: token ids fed to ``embed_input_ids``.
        positions: position ids passed to every layer.
        layers: decoder layers, each called as ``layer(positions, hidden)``.
        embed_input_ids: embedding callable mapping ids to hidden states.
        norm: final normalization callable.

    Returns:
        Normalized hidden states from the last layer.
    """
    hidden_states = embed_input_ids(input_ids)
    # Iterate the layers directly instead of indexing via range(len(...)).
    for layer in layers:
        hidden_states = layer(
            positions,
            hidden_states,
        )
    return norm(hidden_states)
def hunyuan_decoder_model_forward_base_pp(
    config: PretrainedConfig,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel model forward for HunYuan-style decoders with
    cross-layer attention (CLA).

    The first PP rank embeds the inputs (or takes ``inputs_embeds``);
    later ranks pick the hidden states up from ``intermediate_tensors``.
    Only layers [start_layer, end_layer) run on this rank. Non-last ranks
    return an ``IntermediateTensors`` for the next stage instead of
    applying the final norm.
    """
    if not get_pp_group().is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
    elif inputs_embeds is not None:
        hidden_states = inputs_embeds
    else:
        hidden_states = embed_input_ids(input_ids)

    cla_factor = getattr(config, "cla_share_factor", 1)
    prev_kv_states = None
    for offset, layer_idx in enumerate(range(start_layer, end_layer)):
        hidden_states, kv_states = layers[layer_idx](
            positions,
            hidden_states,
            prev_kv_states,
        )
        # KV states from every cla_factor-th layer in this rank's range
        # are carried to the next layer; otherwise the carry is cleared.
        prev_kv_states = kv_states if offset % cla_factor == 0 else None

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })
    return norm(hidden_states)
def decoder_model_forward_base_pp(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel decoder-model forward.

    The first PP rank embeds the inputs (or takes ``inputs_embeds``);
    later ranks pick the hidden states up from ``intermediate_tensors``.
    Only layers [start_layer, end_layer) run on this rank. Non-last ranks
    return an ``IntermediateTensors`` for the next stage instead of
    applying the final norm.
    """
    if not get_pp_group().is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
    elif inputs_embeds is not None:
        hidden_states = inputs_embeds
    else:
        hidden_states = embed_input_ids(input_ids)

    for layer_idx in range(start_layer, end_layer):
        hidden_states = layers[layer_idx](
            positions,
            hidden_states,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })
    return norm(hidden_states)
def is_smoothquant(quant_config: QuantizationConfig) -> bool:
    """Return True when ``quant_config`` is a SmoothQuant configuration."""
    if quant_config is None:
        return False
    return quant_config.get_name() == "SmoothQuant"
def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
    """Return True when ``quant_config`` is SmoothQuant with per-token
    input quantization."""
    if not is_smoothquant(quant_config):
        return False
    return quant_config.input_quant_method == "per_token"
def compute_in_loop(func: Callable,
                    input: torch.Tensor,
                    chunk_size: int,
                    feature_size: Optional[int] = None,
                    **kwargs):
    """Apply ``func`` chunk-by-chunk along dimension 0 of ``input``.

    Splits ``input`` into chunks of at most ``chunk_size`` rows, runs
    ``func`` on each chunk sequentially, and writes the results into a
    pre-allocated output buffer. Falls back to a single direct call when
    one chunk covers the whole input.

    Args:
        func: callable applied to each chunk; ``kwargs`` are forwarded.
        input: tensor chunked along dim 0; assumed 2-D when
            ``feature_size`` is not given (output width defaults to
            ``input.shape[1]``).
        chunk_size: maximum number of rows per chunk.
        feature_size: width of the output feature dimension. Provide it
            when ``func`` changes the feature dimension.
    """
    num_rows = input.shape[0]
    # Single-chunk fast path: no buffer allocation, just call through.
    if chunk_size >= num_rows:
        return func(input, **kwargs)
    out_width = feature_size or input.shape[1]
    output = input.new_empty(num_rows, out_width)
    for begin in range(0, num_rows, chunk_size):
        stop = min(begin + chunk_size, num_rows)
        output[begin:stop] = func(input[begin:stop], **kwargs)
    return output