# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from operator import attrgetter

import numpy as np
import torch

import xtrt_llm
from xtrt_llm._utils import pad_vocab_size, str_dtype_to_torch
from xtrt_llm.models import GPTNeoXForCausalLM

UINT4_TO_INT4_FLAG = 1
GPTQ_FLAG = 1
GROUP_SIZE = 128


def numpy_split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    else:
        return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx])


def torch_split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    else:
        return (torch.split(v, v.shape[dim] // tp_size,
                            dim=dim)[idx]).contiguous()


def unpack_int32_into_int8(w_packed):
    # Unpack inputs packed in int32/float32 into uint4 and store them in
    # int8 format
    w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
    w_unpacked = torch.zeros(w_packed_int4x2.shape[0],
                             w_packed_int4x2.shape[1] * 2,
                             dtype=torch.int8)
    w_unpacked[:, ::2] = w_packed_int4x2 % 16
    w_unpacked[:, 1::2] = w_packed_int4x2 // 16
    return w_unpacked.contiguous()


def preprocess_groupwise_weight_params(qweight_unpacked_int8, scales_fp16,
                                       qzeros_unpacked_int8):
    packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
    preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm

    qweight_interleaved = preprocessor(packer(qweight_unpacked_int8),
                                       torch.quint4x2).view(torch.float32)

    # zeros = zeros * scales
    zeros_x_scales_fp16 = (-qzeros_unpacked_int8 + 8 * UINT4_TO_INT4_FLAG -
                           GPTQ_FLAG) * scales_fp16
    zeros_x_scales_fp16 = zeros_x_scales_fp16.half()

    # return processed interleaved weight, original scales and zeros * scales
    return qweight_interleaved.contiguous().numpy(), scales_fp16.contiguous(
    ).numpy(), zeros_x_scales_fp16.contiguous().numpy()
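

# Illustrative sanity check (a sketch, not used by the loader): on a
# little-endian host, viewing one int32 as four uint8 bytes and splitting each
# byte into its low and high nibbles recovers the eight packed 4-bit values in
# order. The function below is hypothetical and exists only to document the
# nibble order that unpack_int32_into_int8 assumes.
def _demo_unpack_int32_into_int8():
    packed = torch.tensor([[0x76543210]], dtype=torch.int32)
    unpacked = unpack_int32_into_int8(packed)
    # Low nibble first within each byte, bytes in little-endian order.
    assert unpacked.tolist() == [[0, 1, 2, 3, 4, 5, 6, 7]]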


def load_from_hf_gpt_neox(xtrt_llm_gpt_neox: GPTNeoXForCausalLM,
                          hf_gpt_neox,
                          dtype="float32",
                          rank=0,
                          tp_size=1,
                          use_weight_only_groupwise_quant_matmul_plugin=False):
    hf_model_gptneox_block_names = [
        "input_layernorm.weight",
        "input_layernorm.bias",
        "post_attention_layernorm.weight",
        "post_attention_layernorm.bias",
    ]

    xtrt_llm_model_gptneox_block_names = [
        "input_layernorm.weight",
        "input_layernorm.bias",
        "post_attention_layernorm.weight",
        "post_attention_layernorm.bias",
    ]

    if not use_weight_only_groupwise_quant_matmul_plugin:
        hf_model_gptneox_block_names += [
            "attention.dense.weight",
            "attention.dense.bias",
            "mlp.dense_h_to_4h.weight",
            "mlp.dense_h_to_4h.bias",
            "mlp.dense_4h_to_h.weight",
            "mlp.dense_4h_to_h.bias",
        ]
        xtrt_llm_model_gptneox_block_names += [
            "attention.dense.weight",
            "attention.dense.bias",
            "mlp.fc.weight",
            "mlp.fc.bias",
            "mlp.proj.weight",
            "mlp.proj.bias",
        ]

    if not use_weight_only_groupwise_quant_matmul_plugin:
        xtrt_llm.logger.info('Loading weights from HF GPT-NeoX...')
    else:
        xtrt_llm.logger.info(
            'Loading weights from GPTQ quantized HF GPT-NeoX...')
    tik = time.time()

    torch_dtype = str_dtype_to_torch(dtype)
    hf_gpt_neox_state_dict = hf_gpt_neox.state_dict()

    # [vocab_size, hidden_size]
    v = hf_gpt_neox_state_dict.get('gpt_neox.embed_in.weight').to(
        torch_dtype).cpu().numpy()
    if xtrt_llm_gpt_neox._use_parallel_embedding:
        v = numpy_split(v, tp_size, rank,
                        xtrt_llm_gpt_neox._embedding_sharding_dim)
    xtrt_llm_gpt_neox.embedding.weight.value = v

    n_layer = hf_gpt_neox.config.num_hidden_layers

    for layer_idx in range(n_layer):
        prefix = "gpt_neox.layers." + str(layer_idx) + "."
        for idx, hf_attr in enumerate(hf_model_gptneox_block_names):
            v = hf_gpt_neox_state_dict.get(prefix + hf_attr).to(
                torch_dtype).cpu().numpy()
            layer = attrgetter(xtrt_llm_model_gptneox_block_names[idx])(
                xtrt_llm_gpt_neox.layers[layer_idx])
            if tp_size > 1:
                if 'dense.weight' in hf_attr:
                    # [n=hidden_size, k=hidden_size] ->
                    # [n=hidden_size, k=hidden_size // tp_size]
                    split_v = numpy_split(v, tp_size, rank, dim=1)
                elif 'dense_h_to_4h.weight' in hf_attr:
                    # [hidden_size * 4, hidden_size] ->
                    # [hidden_size * 4 // tp_size, hidden_size]
                    split_v = numpy_split(v, tp_size, rank, dim=0)
                elif 'dense_h_to_4h.bias' in hf_attr:
                    # [hidden_size * 4] -> [hidden_size * 4 // tp_size]
                    split_v = numpy_split(v, tp_size, rank, dim=0)
                elif 'dense_4h_to_h.weight' in hf_attr:
                    # [hidden_size, hidden_size * 4] ->
                    # [hidden_size, hidden_size * 4 // tp_size]
                    split_v = numpy_split(v, tp_size, rank, dim=1)
                else:
                    split_v = v
                setattr(layer, 'value', split_v)
            else:
                setattr(layer, 'value', v)

        num_heads = hf_gpt_neox.config.num_attention_heads
        hidden_size = hf_gpt_neox.config.hidden_size
        head_size = hidden_size // num_heads

        if not use_weight_only_groupwise_quant_matmul_plugin:
            # Attention QKV Linear
            # qkv_weights [num_heads x (q|k|v), hidden_size] ->
            # [(num_heads x q)|(num_heads x k)|(num_heads x v), hidden_size]
            qkv_weights = hf_gpt_neox_state_dict.get(
                prefix + "attention.query_key_value.weight")
            qkv_bias = hf_gpt_neox_state_dict.get(
                prefix + "attention.query_key_value.bias")

            new_qkv_weight_shape = torch.Size(
                [num_heads, 3, head_size * qkv_weights.size()[-1]])
            new_qkv_bias_shape = torch.Size([num_heads, 3, head_size])

            qkv_weights = qkv_weights.view(new_qkv_weight_shape).permute(
                1, 0, 2).reshape([hidden_size * 3, hidden_size])
            qkv_bias = qkv_bias.view(new_qkv_bias_shape).permute(
                1, 0, 2).reshape([hidden_size * 3])

            if tp_size > 1:
                qkv_weights = qkv_weights.reshape(
                    3, hidden_size, hidden_size).to(torch_dtype).cpu().numpy()
                split_qkv_weights = numpy_split(
                    qkv_weights, tp_size, rank,
                    dim=1).reshape(3 * (hidden_size // tp_size), hidden_size)
                xtrt_llm_gpt_neox.layers[layer_idx].attention.qkv.weight.value = \
                    np.ascontiguousarray(split_qkv_weights)

                qkv_bias = qkv_bias.reshape(
                    3, hidden_size).to(torch_dtype).cpu().numpy()
                split_qkv_bias = numpy_split(qkv_bias, tp_size, rank,
                                             dim=1).reshape(
                                                 3 * (hidden_size // tp_size))
                xtrt_llm_gpt_neox.layers[layer_idx].attention.qkv.bias.value = \
                    np.ascontiguousarray(split_qkv_bias)
            else:
                xtrt_llm_gpt_neox.layers[layer_idx].attention.qkv.weight.value = \
                    qkv_weights.to(torch_dtype).cpu().numpy()
                xtrt_llm_gpt_neox.layers[layer_idx].attention.qkv.bias.value = \
                    qkv_bias.to(torch_dtype).cpu().numpy()
        else:
            # use_weight_only_groupwise_quant_matmul_plugin
            qweight_int32 = hf_gpt_neox_state_dict.get(
                prefix + "attention.query_key_value.qweight")
            scales_fp16 = hf_gpt_neox_state_dict.get(
                prefix + "attention.query_key_value.scales")
            qzeros_int32 = hf_gpt_neox_state_dict.get(
                prefix + "attention.query_key_value.qzeros")
            biases_fp16 = hf_gpt_neox_state_dict.get(
                prefix + "attention.query_key_value.bias")

            # [hidden_size // 8, hidden_size * 3] ->
            # [hidden_size * 3, hidden_size]
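            # The shapes above and below assume a GPTQ/AutoGPTQ-style
            # checkpoint layout, where each linear layer is stored as:
            #   qweight: int32 [in_features // 8, out_features], eight 4-bit
            #            values packed into each int32
            #   scales:  fp16  [in_features // GROUP_SIZE, out_features]
            #   qzeros:  int32 [in_features // GROUP_SIZE, out_features // 8]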
            qweight_unpacked_int8 = unpack_int32_into_int8(
                qweight_int32.T).contiguous() - 8
            # [hidden_size // GROUP_SIZE, hidden_size * 3 // 8] ->
            # [hidden_size // GROUP_SIZE, hidden_size * 3]
            qzeros_unpacked_int8 = unpack_int32_into_int8(qzeros_int32)

            # qkv_weights [num_heads x (q|k|v), hidden_size] ->
            # [(num_heads x q)|(num_heads x k)|(num_heads x v), hidden_size]
            new_qkv_weight_shape = torch.Size(
                [num_heads, 3, head_size * qweight_unpacked_int8.size()[-1]])
            # [hidden_size * 3, hidden_size]
            qweight_unpacked_int8 = qweight_unpacked_int8.view(
                new_qkv_weight_shape).permute(1, 0, 2).reshape(
                    [hidden_size * 3, hidden_size]).contiguous()

            new_qkv_scale_shape = torch.Size(
                [num_heads, 3, head_size * (hidden_size // GROUP_SIZE)])
            # [hidden_size * 3, hidden_size // GROUP_SIZE]
            scales_fp16 = scales_fp16.T.contiguous().view(
                new_qkv_scale_shape).permute(1, 0, 2).reshape(
                    [hidden_size * 3, hidden_size // GROUP_SIZE]).contiguous()

            new_qkv_zero_shape = torch.Size(
                [num_heads, 3, head_size * (hidden_size // GROUP_SIZE)])
            # [hidden_size * 3, hidden_size // GROUP_SIZE]
            qzeros_unpacked_int8 = qzeros_unpacked_int8.T.contiguous().view(
                new_qkv_zero_shape).permute(1, 0, 2).reshape(
                    [hidden_size * 3, hidden_size // GROUP_SIZE]).contiguous()

            new_qkv_bias_shape = torch.Size([num_heads, 3, head_size])
            biases_fp16 = biases_fp16.view(new_qkv_bias_shape).permute(
                1, 0, 2).reshape([hidden_size * 3]).numpy()

            if tp_size > 1:
                qweight_unpacked_int8 = qweight_unpacked_int8.reshape(
                    [3, hidden_size, hidden_size])
                qweight_unpacked_int8 = torch_split(qweight_unpacked_int8,
                                                    tp_size, rank, dim=1)
                qweight_unpacked_int8 = qweight_unpacked_int8.reshape(
                    [3 * hidden_size // tp_size, hidden_size])

                scales_fp16 = scales_fp16.reshape(
                    [3, hidden_size, hidden_size // GROUP_SIZE])
                scales_fp16 = torch_split(scales_fp16, tp_size, rank, dim=1)
                scales_fp16 = scales_fp16.reshape(
                    [3 * hidden_size // tp_size, hidden_size // GROUP_SIZE])

                qzeros_unpacked_int8 = qzeros_unpacked_int8.reshape(
                    [3, hidden_size, hidden_size // GROUP_SIZE])
                qzeros_unpacked_int8 = torch_split(qzeros_unpacked_int8,
                                                   tp_size, rank, dim=1)
                qzeros_unpacked_int8 = qzeros_unpacked_int8.reshape(
                    [3 * hidden_size // tp_size, hidden_size // GROUP_SIZE])

                biases_fp16 = biases_fp16.reshape([3, hidden_size])
                biases_fp16 = numpy_split(biases_fp16, tp_size, rank, dim=1)
                biases_fp16 = biases_fp16.reshape([3 * hidden_size // tp_size])

            qweight_fp32, scales_fp16, zeros_fp16 = preprocess_groupwise_weight_params(
                qweight_unpacked_int8.T.contiguous(),
                scales_fp16.T.contiguous(),
                qzeros_unpacked_int8.T.contiguous())

            xtrt_llm_gpt_neox.layers[layer_idx].attention.qkv.qweight.value = \
                qweight_fp32
            xtrt_llm_gpt_neox.layers[layer_idx].attention.qkv.scale.value = \
                scales_fp16
            xtrt_llm_gpt_neox.layers[layer_idx].attention.qkv.zero.value = \
                zeros_fp16
            xtrt_llm_gpt_neox.layers[layer_idx].attention.qkv.bias.value = \
                biases_fp16
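
            # The mixed-GEMM preprocessor consumes weights in [k, n]
            # (input-major) layout, which is why the QKV tensors built above
            # in [n, k] order are transposed right before the call, while the
            # three projections below are already arranged as [k, n] when
            # they are unpacked.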

            qweight_int32 = hf_gpt_neox_state_dict.get(
                prefix + "attention.dense.qweight")
            scales_fp16 = hf_gpt_neox_state_dict.get(
                prefix + "attention.dense.scales")
            qzeros_int32 = hf_gpt_neox_state_dict.get(
                prefix + "attention.dense.qzeros")
            biases_fp16 = hf_gpt_neox_state_dict.get(
                prefix + "attention.dense.bias").numpy()

            # [k=hidden_size // 8, n=hidden_size] ->
            # [n=hidden_size, k=hidden_size]
            qweight_unpacked_int8 = unpack_int32_into_int8(
                qweight_int32.T).contiguous() - 8
            # [n=hidden_size, k=hidden_size] -> [k=hidden_size, n=hidden_size]
            qweight_unpacked_int8 = qweight_unpacked_int8.T.contiguous()
            # [k=hidden_size // GROUP_SIZE, n=hidden_size // 8] ->
            # [k=hidden_size // GROUP_SIZE, n=hidden_size]
            qzeros_unpacked_int8 = unpack_int32_into_int8(qzeros_int32)

            if tp_size > 1:
                qweight_unpacked_int8 = torch_split(qweight_unpacked_int8,
                                                    tp_size, rank, dim=0)
                scales_fp16 = torch_split(scales_fp16, tp_size, rank, dim=0)
                qzeros_unpacked_int8 = torch_split(qzeros_unpacked_int8,
                                                   tp_size, rank, dim=0)
                if rank > 0:
                    biases_fp16 = np.zeros_like(biases_fp16)

            qweight_fp32, scales_fp16, zeros_fp16 = preprocess_groupwise_weight_params(
                qweight_unpacked_int8, scales_fp16, qzeros_unpacked_int8)

            xtrt_llm_gpt_neox.layers[layer_idx].attention.dense.qweight.value = \
                qweight_fp32
            xtrt_llm_gpt_neox.layers[layer_idx].attention.dense.scale.value = \
                scales_fp16
            xtrt_llm_gpt_neox.layers[layer_idx].attention.dense.zero.value = \
                zeros_fp16
            xtrt_llm_gpt_neox.layers[layer_idx].attention.dense.bias.value = \
                biases_fp16

            qweight_int32 = hf_gpt_neox_state_dict.get(
                prefix + "mlp.dense_h_to_4h.qweight")
            scales_fp16 = hf_gpt_neox_state_dict.get(
                prefix + "mlp.dense_h_to_4h.scales")
            qzeros_int32 = hf_gpt_neox_state_dict.get(
                prefix + "mlp.dense_h_to_4h.qzeros")
            biases_fp16 = hf_gpt_neox_state_dict.get(
                prefix + "mlp.dense_h_to_4h.bias").numpy()

            # [hidden_size // 8, hidden_size * 4] ->
            # [hidden_size, hidden_size * 4]
            qweight_unpacked_int8 = unpack_int32_into_int8(
                qweight_int32.T).contiguous() - 8
            qweight_unpacked_int8 = qweight_unpacked_int8.T.contiguous()
            # [hidden_size // GROUP_SIZE, hidden_size * 4 // 8] ->
            # [hidden_size // GROUP_SIZE, hidden_size * 4]
            qzeros_unpacked_int8 = unpack_int32_into_int8(qzeros_int32)

            if tp_size > 1:
                # [hidden_size, hidden_size * 4] ->
                # [hidden_size, hidden_size * 4 // tp_size]
                qweight_unpacked_int8 = torch_split(qweight_unpacked_int8,
                                                    tp_size, rank, dim=1)
                # [hidden_size // GROUP_SIZE, hidden_size * 4] ->
                # [hidden_size // GROUP_SIZE, hidden_size * 4 // tp_size]
                scales_fp16 = torch_split(scales_fp16, tp_size, rank, dim=1)
                # [hidden_size // GROUP_SIZE, hidden_size * 4] ->
                # [hidden_size // GROUP_SIZE, hidden_size * 4 // tp_size]
                qzeros_unpacked_int8 = torch_split(qzeros_unpacked_int8,
                                                   tp_size, rank, dim=1)
                # [hidden_size * 4] -> [hidden_size * 4 // tp_size]
                biases_fp16 = numpy_split(biases_fp16, tp_size, rank, dim=0)

            qweight_fp32, scales_fp16, zeros_fp16 = preprocess_groupwise_weight_params(
                qweight_unpacked_int8, scales_fp16, qzeros_unpacked_int8)

            xtrt_llm_gpt_neox.layers[layer_idx].mlp.fc.qweight.value = \
                qweight_fp32
            xtrt_llm_gpt_neox.layers[layer_idx].mlp.fc.scale.value = \
                scales_fp16
            xtrt_llm_gpt_neox.layers[layer_idx].mlp.fc.zero.value = \
                zeros_fp16
            xtrt_llm_gpt_neox.layers[layer_idx].mlp.fc.bias.value = \
                biases_fp16
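
            # mlp.dense_h_to_4h (fc) above is column-parallel: the output
            # dimension is split across ranks, so its bias is split as well.
            # mlp.dense_4h_to_h (proj) below is row-parallel, like
            # attention.dense: the input dimension is split and the partial
            # outputs are reduced across ranks, so only rank 0 keeps the bias
            # to avoid adding it tp_size times.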

            qweight_int32 = hf_gpt_neox_state_dict.get(
                prefix + "mlp.dense_4h_to_h.qweight")
            scales_fp16 = hf_gpt_neox_state_dict.get(
                prefix + "mlp.dense_4h_to_h.scales")
            qzeros_int32 = hf_gpt_neox_state_dict.get(
                prefix + "mlp.dense_4h_to_h.qzeros")
            biases_fp16 = hf_gpt_neox_state_dict.get(
                prefix + "mlp.dense_4h_to_h.bias").numpy()

            # [hidden_size * 4 // 8, hidden_size] ->
            # [hidden_size * 4, hidden_size]
            qweight_unpacked_int8 = unpack_int32_into_int8(
                qweight_int32.T).contiguous() - 8
            qweight_unpacked_int8 = qweight_unpacked_int8.T.contiguous()
            # [hidden_size * 4 // GROUP_SIZE, hidden_size // 8] ->
            # [hidden_size * 4 // GROUP_SIZE, hidden_size]
            qzeros_unpacked_int8 = unpack_int32_into_int8(qzeros_int32)

            if tp_size > 1:
                # [hidden_size * 4, hidden_size] ->
                # [hidden_size * 4 // tp_size, hidden_size]
                qweight_unpacked_int8 = torch_split(qweight_unpacked_int8,
                                                    tp_size, rank, dim=0)
                # [hidden_size * 4 // GROUP_SIZE, hidden_size] ->
                # [hidden_size * 4 // GROUP_SIZE // tp_size, hidden_size]
                scales_fp16 = torch_split(scales_fp16, tp_size, rank, dim=0)
                # [hidden_size * 4 // GROUP_SIZE, hidden_size] ->
                # [hidden_size * 4 // GROUP_SIZE // tp_size, hidden_size]
                qzeros_unpacked_int8 = torch_split(qzeros_unpacked_int8,
                                                   tp_size, rank, dim=0)
                if rank > 0:
                    biases_fp16 = np.zeros_like(biases_fp16)

            qweight_fp32, scales_fp16, zeros_fp16 = preprocess_groupwise_weight_params(
                qweight_unpacked_int8, scales_fp16, qzeros_unpacked_int8)

            xtrt_llm_gpt_neox.layers[layer_idx].mlp.proj.qweight.value = \
                qweight_fp32
            xtrt_llm_gpt_neox.layers[layer_idx].mlp.proj.scale.value = \
                scales_fp16
            xtrt_llm_gpt_neox.layers[layer_idx].mlp.proj.zero.value = \
                zeros_fp16
            xtrt_llm_gpt_neox.layers[layer_idx].mlp.proj.bias.value = \
                biases_fp16

    v = hf_gpt_neox_state_dict.get('gpt_neox.final_layer_norm.weight')
    xtrt_llm_gpt_neox.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()

    v = hf_gpt_neox_state_dict.get('gpt_neox.final_layer_norm.bias')
    xtrt_llm_gpt_neox.ln_f.bias.value = v.to(torch_dtype).cpu().numpy()

    v = hf_gpt_neox_state_dict.get('embed_out.weight').to(
        torch_dtype).cpu().numpy()
    if tp_size > 1:
        # [vocab_size, hidden_size] ->
        # [vocab_size // tp_size, hidden_size]
        if v.shape[0] % tp_size != 0:
            # padding
            vocab_size_padded = pad_vocab_size(v.shape[0], tp_size)
            pad_width = vocab_size_padded - v.shape[0]
            v = np.pad(v, ((0, pad_width), (0, 0)),
                       'constant',
                       constant_values=0)
        split_v = numpy_split(v, tp_size, rank, dim=0)
        xtrt_llm_gpt_neox.lm_head.weight.value = split_v
    else:
        xtrt_llm_gpt_neox.lm_head.weight.value = v

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')
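

# Example usage (a sketch, not part of the original module): assuming an
# already-constructed xtrt_llm GPTNeoXForCausalLM and the Hugging Face
# `transformers` package, the loader would be invoked roughly as follows.
# The variable `xtrt_llm_model` is hypothetical.
#
#   from transformers import GPTNeoXForCausalLM as HFGPTNeoXForCausalLM
#   hf_model = HFGPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m")
#   load_from_hf_gpt_neox(xtrt_llm_model, hf_model, dtype="float16",
#                         rank=0, tp_size=1)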