# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
from typing import Callable, Optional, List, Union, Tuple

from vllm_mlu import _mlu_ops as mlu_ops
from vllm.attention import AttentionMetadata
from vllm.sequence import IntermediateTensors
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from transformers import PretrainedConfig


def hunyuan_decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    kv_states: Optional[Tuple[torch.Tensor]] = None,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shared forward pass for one HunYuan decoder layer.

    Unlike the generic ``decoder_layer_forward_base``, the attention call
    receives (and returns) ``kv_states`` so layers can share KV states
    across layers (cross-layer attention).

    Args:
        positions: position ids tensor, forwarded to attention under
            ``position_name``.
        hidden_states: layer input hidden states.
        input_layernorm: pre-attention norm. When ``input_norm_fuse_en`` is
            True it must return ``(normed, smooth_quant_scale)``; otherwise
            just the normed tensor.
        self_attn: attention callable; must return
            ``(attention_output, kv_states)``.
        post_layernorm: post-attention norm.
        mlp: feed-forward callable, called as ``mlp(normed, residual)``.
        kv_states: KV states from an earlier layer, or None to compute fresh.
        apply_residual_connection_post_layernorm: if True, the residual is
            taken after the norm instead of before it.
        position_name: keyword name under which ``positions`` is passed to
            ``self_attn``.
        input_norm_fuse_en: whether ``input_layernorm`` fuses smooth-quant
            scale computation.

    Returns:
        Tuple of (layer output hidden states, KV states produced by this
        layer's attention).
    """
    if input_norm_fuse_en:
        # Fused norm also produces the smooth-quant scale for the attn input.
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        smooth_quant_scale = None

    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self attention (may reuse cross-layer KV states).
    attention_output, ori_kv_states = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        kv_states=kv_states,
        smooth_quant_scale=smooth_quant_scale,
    )

    layernorm_output = post_layernorm(attention_output)
    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # Fully connected.
    hidden_states = mlp(layernorm_output, residual)
    return hidden_states, ori_kv_states


def decoder_layer_forward_base(
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    input_layernorm: Callable,
    self_attn: Callable,
    post_layernorm: Callable,
    mlp: Callable,
    apply_residual_connection_post_layernorm: bool = False,
    position_name: str = 'positions',
    input_norm_fuse_en: bool = False,
    post_norm_fuse_en: bool = False,
) -> torch.Tensor:
    """Shared forward pass for a generic transformer decoder layer.

    Args:
        positions: position ids tensor, forwarded to attention under
            ``position_name``.
        hidden_states: layer input hidden states.
        input_layernorm: pre-attention norm. When ``input_norm_fuse_en`` is
            True it must return ``(normed, smooth_quant_scale)``.
        self_attn: attention callable returning the attention output.
        post_layernorm: post-attention norm. When ``post_norm_fuse_en`` is
            True it must return ``(normed, smooth_quant_scale)``.
        mlp: feed-forward callable; receives ``smooth_quant_scale`` as a
            keyword only when ``post_norm_fuse_en`` is True.
        apply_residual_connection_post_layernorm: if True, the residual is
            taken after the norm instead of before it.
        position_name: keyword name under which ``positions`` is passed to
            ``self_attn``.
        input_norm_fuse_en: whether ``input_layernorm`` fuses smooth-quant
            scale computation.
        post_norm_fuse_en: whether ``post_layernorm`` fuses smooth-quant
            scale computation.

    Returns:
        Layer output hidden states.
    """
    if input_norm_fuse_en:
        layernorm_output, smooth_quant_scale = input_layernorm(hidden_states)
    else:
        layernorm_output = input_layernorm(hidden_states)
        smooth_quant_scale = None

    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self attention.
    attention_output = self_attn(
        **{position_name: positions},
        hidden_states=layernorm_output,
        residual=residual,
        smooth_quant_scale=smooth_quant_scale,
    )

    if post_norm_fuse_en:
        layernorm_output, smooth_quant_scale = post_layernorm(attention_output)
    else:
        layernorm_output = post_layernorm(attention_output)
        smooth_quant_scale = None

    if apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # Fully connected. Only pass the scale when the fused post-norm produced
    # one, so non-fused MLP implementations need not accept the keyword.
    kwargs = dict()
    if post_norm_fuse_en:
        kwargs['smooth_quant_scale'] = smooth_quant_scale
    hidden_states = mlp(layernorm_output, residual, **kwargs)
    return hidden_states


def decoder_model_forward_base(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    layers: torch.nn.ModuleList,
    embed_input_ids: Callable,
    norm: Callable
) -> torch.Tensor:
    """Shared forward pass for a full decoder model (no pipeline parallel).

    Embeds ``input_ids``, runs every layer in order, then applies the final
    norm.

    Args:
        input_ids: token ids to embed.
        positions: position ids, passed to every layer.
        layers: decoder layers, applied in sequence.
        embed_input_ids: embedding callable mapping ids to hidden states.
        norm: final normalization applied to the last layer's output.

    Returns:
        Final normalized hidden states.
    """
    hidden_states = embed_input_ids(input_ids)
    # Iterate layers directly; the index was only used for lookup.
    for layer in layers:
        hidden_states = layer(
            positions,
            hidden_states,
        )
    hidden_states = norm(hidden_states)
    return hidden_states


def hunyuan_decoder_model_forward_base_pp(
    config: PretrainedConfig,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel forward pass for a HunYuan decoder model.

    The first PP rank embeds the inputs; later ranks consume
    ``intermediate_tensors``. KV states are shared across every
    ``cla_share_factor`` consecutive layers (cross-layer attention): a layer
    at a multiple of the factor computes fresh KV states that the following
    layers in its group reuse.

    Args:
        config: model config; ``cla_share_factor`` (default 1) controls KV
            sharing.
        input_ids: token ids (first PP rank only, when ``inputs_embeds`` is
            not given).
        positions: position ids, passed to every layer.
        intermediate_tensors: hidden states from the previous PP rank
            (required on non-first ranks).
        layers: full decoder layer list; only ``[start_layer, end_layer)``
            runs on this rank.
        start_layer: first layer index for this rank (inclusive).
        end_layer: last layer index for this rank (exclusive).
        embed_input_ids: embedding callable (first PP rank only).
        norm: final normalization (last PP rank only).
        inputs_embeds: precomputed embeddings that bypass ``embed_input_ids``.

    Returns:
        Final hidden states on the last PP rank, otherwise
        ``IntermediateTensors`` to forward to the next rank.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = embed_input_ids(input_ids)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    cla_factor = getattr(config, "cla_share_factor", 1)
    prev_kv_states = None
    for i in range(start_layer, end_layer):
        layer = layers[i]
        hidden_states, kv_states = layer(
            positions,
            hidden_states,
            prev_kv_states,
        )

        # The first layer of each CLA group produces KV states that the rest
        # of the group reuses; outside a group boundary nothing is carried.
        if (i - start_layer) % cla_factor == 0:
            prev_kv_states = kv_states
        else:
            prev_kv_states = None

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })

    hidden_states = norm(hidden_states)
    return hidden_states


def decoder_model_forward_base_pp(
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    layers: torch.nn.ModuleList,
    start_layer: int,
    end_layer: int,
    embed_input_ids: Callable,
    norm: Callable,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel forward pass for a generic decoder model.

    The first PP rank embeds the inputs; later ranks consume
    ``intermediate_tensors``. Only layers ``[start_layer, end_layer)`` run
    on this rank; non-last ranks return ``IntermediateTensors``.

    Args:
        input_ids: token ids (first PP rank only, when ``inputs_embeds`` is
            not given).
        positions: position ids, passed to every layer.
        intermediate_tensors: hidden states from the previous PP rank
            (required on non-first ranks).
        layers: full decoder layer list.
        start_layer: first layer index for this rank (inclusive).
        end_layer: last layer index for this rank (exclusive).
        embed_input_ids: embedding callable (first PP rank only).
        norm: final normalization (last PP rank only).
        inputs_embeds: precomputed embeddings that bypass ``embed_input_ids``.

    Returns:
        Final hidden states on the last PP rank, otherwise
        ``IntermediateTensors`` to forward to the next rank.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = embed_input_ids(input_ids)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    for i in range(start_layer, end_layer):
        layer = layers[i]
        hidden_states = layer(
            positions,
            hidden_states,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors({
            "hidden_states": hidden_states,
        })

    hidden_states = norm(hidden_states)
    return hidden_states


def is_smoothquant(quant_config: QuantizationConfig) -> bool:
    """Return True when ``quant_config`` is a SmoothQuant config."""
    return (quant_config is not None
            and quant_config.get_name() == "SmoothQuant")


def is_per_token_smoothquant(quant_config: QuantizationConfig) -> bool:
    """Return True for a SmoothQuant config using per-token input quant."""
    return (is_smoothquant(quant_config)
            and quant_config.input_quant_method == "per_token")


def compute_in_loop(func: Callable,
                    input: torch.Tensor,
                    chunk_size: int,
                    feature_size: Optional[int] = None,
                    **kwargs):
    """Apply ``func`` to ``input`` chunk-by-chunk along dimension 0.

    Divides ``input`` into chunks of at most ``chunk_size`` rows in the
    leading dimension and computes the chunks in a loop instead of in one
    batch, bounding peak memory.

    Args:
        func: callable applied to each chunk; must return a tensor with the
            same number of rows as its input chunk.
        input: 2-D tensor to process. (NOTE: shadows the ``input`` builtin;
            name kept for caller keyword-argument compatibility.)
        chunk_size: maximum number of rows handled per call to ``func``.
        feature_size: size of the output feature dimension. Provide it when
            the output's feature dimension differs from the input's.
        **kwargs: forwarded unchanged to every ``func`` call.

    Returns:
        Tensor of shape ``(input.shape[0], feature_size)`` holding the
        concatenated chunk results.
    """
    total = input.shape[0]
    # Single chunk: no output buffer needed, call straight through.
    if chunk_size >= total:
        return func(input, **kwargs)
    # Use an explicit None check (not `or`) so feature_size=0 is honored.
    if feature_size is None:
        feature_size = input.shape[1]
    output = input.new_empty(total, feature_size)
    num_chunks = (total + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start = i * chunk_size
        end = min(start + chunk_size, total)
        output[start:end] = func(input[start:end], **kwargs)
    return output