# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
import os
import re
from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import BloomConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

# Parameter-name fragments whose weights are re-laid-out with
# ``ops.trans_w16_gemm`` when the LLAMA_NN=1 path is enabled.
_LLAMA_NN_WEIGHT_KEYS = (
    "self_attention.query_key_value.weight",
    "self_attention.dense.weight",
    "mlp.dense_h_to_4h.weight",
    "mlp.dense_4h_to_h.weight",
)
# Each key is escaped so "." matches a literal dot rather than any character.
_LLAMA_NN_WEIGHT_PATTERN = re.compile("|".join(
    re.escape(key) for key in _LLAMA_NN_WEIGHT_KEYS))


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return the per-head ALiBi slopes for ``total_num_heads`` heads.

    Let ``p`` be the largest power of two ``<= total_num_heads``. The first
    ``p`` slopes are ``base ** k`` for ``k = 1..p`` with
    ``base = 2 ** (-8 / p)``. When the head count is not a power of two, the
    remaining heads take the odd powers ``1, 3, 5, ...`` of
    ``2 ** (-8 / (2 * p))``.

    Returns:
        A float32 tensor with one slope per head.
    """
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):
    """Multi-head self-attention with ALiBi slopes, sharded across TP ranks.

    Each tensor-parallel rank owns ``n_head // tp_world_size`` heads and the
    matching contiguous slice of the ALiBi slopes.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Create the alibi slopes and keep only this rank's heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Always define both attributes. Previously ``quant_config`` was set
        # only for quantized models, so reading it otherwise raised
        # AttributeError.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the fused QKV projection, attention and output projection.

        ``position_ids`` is accepted for interface compatibility but unused;
        positional information comes from the ALiBi slopes given to
        ``Attention``.
        """
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """Feed-forward block: hidden -> 4 * hidden -> GELU -> hidden."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            quant_config=quant_config,
        )
        self.gelu_impl = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up, apply GELU, and project back down."""
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):
    """One BLOOM transformer layer: pre-LN attention then pre-LN MLP.

    ``config.apply_residual_connection_post_layernorm`` chooses whether each
    residual branch adds the layer-normed input or the raw input.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config,
                                             cache_config,
                                             quant_config,
                                             prefix=f"{prefix}.self_attention")
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, quant_config)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, each with a residual add."""
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Choose the attention residual: normed or raw input, per config.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Choose the MLP residual the same way.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM decoder stack: embeddings, transformer layers, final LayerNorm.

    Supports pipeline parallelism: only this rank's layers
    (``start_layer:end_layer``) are executed. Non-last ranks return
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: BloomBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

        # Always define both attributes. ``quant_method`` is None when the
        # model is unquantized.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # LLAMA_NN=1 re-lays-out unquantized linear weights after loading.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        # NOTE(review): GEMM_PAD / FA_PAD are still read for compatibility,
        # but no code path in this file currently acts on them.
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids and apply the embedding LayerNorm."""
        return self.word_embeddings_layernorm(self.word_embeddings(input_ids))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's layers.

        The first PP rank starts from ``inputs_embeds`` (if given) or the
        embedded ``input_ids``. Later ranks start from
        ``intermediate_tensors["hidden_states"]``. Only the last rank applies
        ``ln_f``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Parameters owned by another pipeline stage are skipped. Fused QKV
        weights are re-ordered to vLLM's layout. When LLAMA_NN=1 and the
        model is unquantized, selected linear weights are then re-laid-out
        by ``_apply_llama_nn_layout``.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    # Split output_dim into (num_heads, 3, head_size), swap
                    # the first two axes, then flatten back.
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        return loaded_params

    @staticmethod
    def _apply_llama_nn_layout(params_dict: dict[str, nn.Parameter],
                               loaded_names: Iterable[str]) -> None:
        """Re-lay-out loaded linear weights in place for the LLAMA_NN GEMM.

        For every loaded parameter whose name contains one of
        ``_LLAMA_NN_WEIGHT_KEYS``, a ``(rows, cols)`` weight is rewritten by
        ``ops.trans_w16_gemm``. Its data is then re-viewed with shape
        ``(cols, -1)``.

        NOTE(review): the exact element order written by ``trans_w16_gemm``
        is defined by the custom kernel (presumably a transposed/tiled
        16-bit layout). Confirm against the kernel.
        """
        for name in loaded_names:
            if _LLAMA_NN_WEIGHT_PATTERN.search(name) is None:
                continue
            weight = params_dict[name]
            rows, cols = weight.data.shape
            relaid = torch.zeros_like(weight.data)
            ops.trans_w16_gemm(relaid, weight.data, rows, cols)
            weight.data.copy_(relaid)
            weight.data = weight.data.reshape(cols, -1)


class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
    """BLOOM decoder with a language-modeling head.

    When ``config.tie_word_embeddings`` is set, the LM head reuses the input
    embedding module.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to ``BloomModel.get_input_embeddings``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; logits are produced by ``compute_logits``."""
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights, normalizing names to the ``transformer.`` prefix.

        ``lm_head.weight`` is skipped.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
        weights = _add_transformer_prefix(weights)
        return loader.load_weights(weights)


def _add_transformer_prefix(
    weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    """Lazily prepend ``transformer.`` to names that do not already have it."""
    for name, tensor in weights:
        if not name.startswith('transformer.'):
            name = 'transformer.' + name
        yield name, tensor