# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build and serialize xtrt_llm engines for a GPT-NeoX model."""
import argparse
import json
import os
import time

#import tensorrt as trt
import torch
import torch.multiprocessing as mp
from safetensors import safe_open
from transformers import AutoModelForCausalLM, GPTNeoXConfig
from weight import load_from_hf_gpt_neox

import xtrt_llm
from xtrt_llm._utils import str_dtype_to_xtrt
from xtrt_llm.builder import Builder
from xtrt_llm.logger import logger
from xtrt_llm.mapping import Mapping
from xtrt_llm.models import weight_only_groupwise_quantize, weight_only_quantize
from xtrt_llm.network import net_guard
from xtrt_llm.plugin.plugin import ContextFMHAType
from xtrt_llm.quantization import QuantMode

MODEL_NAME = "gptneox"
# Set by parse_arguments() when --model_dir is given; read by
# build_rank_engine() to load the weights.
hf_gpt = None


class StateDict:
    """Minimal state-dict-like wrapper around a safetensors checkpoint."""

    def __init__(self, quant_ckpt_dir):
        """Open the safetensors file at ``quant_ckpt_dir`` on device 0."""
        self.model_state_dict = safe_open(quant_ckpt_dir,
                                          framework="pt",
                                          device=0)

    def get(self, k):
        """Return the tensor stored under key ``k``, moved to the CPU."""
        return self.model_state_dict.get_tensor(k).cpu()


class GPTQModel:
    """Stand-in for an HF model built from ``config.json`` plus a
    groupwise-quantized safetensors checkpoint.

    Exposes ``config`` (a ``GPTNeoXConfig``) and ``state_dict()``.
    """

    # Keys copied verbatim from config.json onto the GPTNeoXConfig, in order.
    _CONFIG_KEYS = (
        'vocab_size',
        'hidden_size',
        'num_hidden_layers',
        'num_attention_heads',
        'intermediate_size',
        'hidden_act',
        'rotary_pct',
        'rotary_emb_base',
        'max_position_embeddings',
        'initializer_range',
        'layer_norm_eps',
        'use_cache',
        'bos_token_id',
        'eos_token_id',
        'tie_word_embeddings',
    )

    def __init__(self, model_dir, quant_ckpt_dir):
        """Read ``<model_dir>/config.json`` and open the quantized checkpoint.

        @param model_dir: Directory containing ``config.json``.
        @param quant_ckpt_dir: Path handed to ``StateDict`` (safetensors file).
        @raise KeyError: If ``config.json`` lacks one of ``_CONFIG_KEYS``.
        """
        config_path = os.path.join(model_dir, 'config.json')
        with open(config_path, 'r', encoding='utf-8') as f:
            model_config = json.load(f)
        self.config = GPTNeoXConfig()
        for key in self._CONFIG_KEYS:
            # Indexing (not .get) so a missing key fails loudly.
            setattr(self.config, key, model_config[key])
        self.model_state_dict = StateDict(quant_ckpt_dir)

    def state_dict(self):
        """Return the ``StateDict`` wrapper around the quantized checkpoint."""
        return self.model_state_dict


def get_engine_name(model, dtype, tp_size, rank):
    """Return the engine file name ``<model>_<dtype>_tp<tp_size>_rank<rank>.engine``."""
    return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)


def serialize_engine(engine, path):
    """Serialize ``engine`` to ``path`` and log the elapsed time."""
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    engine.serialize(path)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Engine serialized. Total time: {t}')


def _apply_checkpoint_config(args, config):
    """Overwrite the model-shape arguments in ``args`` with values from ``config``."""
    args.n_embd = config.hidden_size
    args.n_head = config.num_attention_heads
    args.n_layer = config.num_hidden_layers
    args.n_positions = config.max_position_embeddings
    args.vocab_size = config.vocab_size
    args.rotary_pct = config.rotary_pct
    # NOTE(review): the original indentation of this assignment was lost; it
    # is applied to both loading paths here -- confirm that is intended.
    args.inter_size = config.intermediate_size


def parse_arguments():
    """Parse the command line and load the source model when requested.

    When ``--model_dir`` is given, the module-level ``hf_gpt`` is set to the
    loaded model (a HF model, or a ``GPTQModel`` when
    ``--use_weight_only_groupwise_quant_matmul_plugin`` is set) and the
    model-shape arguments are overwritten from its config. ``args.quant_mode``
    is always set.

    @return: The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='world size, only support tensor parallelism now')
    parser.add_argument(
        '--model_dir',
        type=str,
        default=None,
        help='The path to HF GPT-NeoX model / checkpoints to read weights from')
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help=
        'The path of to read timing cache from, will be ignored if the file does not exist'
    )
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--vocab_size', type=int, default=50432)
    parser.add_argument('--n_layer', type=int, default=44)
    parser.add_argument('--n_positions', type=int, default=2048)
    parser.add_argument('--n_embd', type=int, default=6144)
    parser.add_argument('--n_head', type=int, default=64)
    parser.add_argument('--hidden_act', type=str, default='gelu')
    parser.add_argument(
        '--rotary_pct',
        type=float,
        default=0.25,
        help="Percentage of hidden dimensions to allocate to rotary embeddings."
    )
    parser.add_argument('--max_batch_size', type=int, default=64)
    parser.add_argument('--max_input_len', type=int, default=1024)
    parser.add_argument('--max_output_len', type=int, default=1024)
    parser.add_argument('--max_beam_width', type=int, default=1)
    parser.add_argument('--use_gpt_attention_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'float32'])
    parser.add_argument('--use_gemm_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'float32'])
    parser.add_argument('--use_weight_only_quant_matmul_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16'])
    parser.add_argument('--use_weight_only_groupwise_quant_matmul_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16'])
    parser.add_argument(
        '--groupwise_quant_safetensors_path',
        type=str,
        default=None,
        help=
        "The path to groupwise quantized GPT-NeoX model / checkpoints to read weights from."
    )
    parser.add_argument('--use_layernorm_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'float32'])
    parser.add_argument('--parallel_build', default=False, action='store_true')
    parser.add_argument('--enable_context_fmha',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument(
        '--output_dir',
        type=str,
        default='gpt_outputs',
        help=
        'The path to save the serialized engine files, timing cache file and model configs'
    )
    parser.add_argument('--remove_input_padding',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
    )
    # NOTE(review): the help text below calls dim 0 the default, but the
    # actual default is 1 -- confirm which one is intended.
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=1,  # Meta does TP on hidden dim
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along vocab dimension (--embedding_sharding_dim=0). '
        'To shard it along hidden dimension, set --embedding_sharding_dim=1 '
        'Note: embedding sharing is only enabled when --embedding_sharding_dim=0'
    )
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument('--inter_size', type=int, default=None)

    args = parser.parse_args()
    xtrt_llm.logger.set_level(args.log_level)

    if args.model_dir is not None:
        global hf_gpt
        if not args.use_weight_only_groupwise_quant_matmul_plugin:
            logger.info(f'Loading HF GPT-NeoX model from {args.model_dir}...')
            hf_gpt = AutoModelForCausalLM.from_pretrained(args.model_dir)
        else:
            assert (
                args.groupwise_quant_safetensors_path is not None
            ), 'Please set the path to the groupwise quantized GPT-NeoX checkpoints with --groupwise_quant_safetensors_path'
            logger.info(
                f'Loading GPTQ quantized HF GPT-NeoX model from {args.groupwise_quant_safetensors_path}...'
            )
            hf_gpt = GPTQModel(args.model_dir,
                               args.groupwise_quant_safetensors_path)
        # The checkpoint's config takes precedence over the CLI shape flags.
        _apply_checkpoint_config(args, hf_gpt.config)

    # quant_mode is read unconditionally by build_rank_engine(), so it is
    # always set here.
    if args.use_weight_only:
        args.quant_mode = QuantMode.use_weight_only(
            args.weight_only_precision == 'int4')
    else:
        args.quant_mode = QuantMode(0)

    return args


def build_rank_engine(builder: Builder,
                      builder_config: xtrt_llm.builder.BuilderConfig,
                      engine_name, rank, args):
    '''
       @brief: Build the engine on the given rank.
       @param builder: The Builder used to create the network and the engine.
       @param builder_config: The builder config for this engine; saved as
              config.json in args.output_dir when rank == 0.
       @param engine_name: The name assigned to the network.
       @param rank: The rank to build the engine.
       @param args: The cmd line arguments.
       @return: The built engine.
    '''
    kv_dtype = str_dtype_to_xtrt(args.dtype)

    # Per-head size scaled by rotary_pct, truncated to an int.
    rotary_dim = int((args.n_embd // args.n_head) * args.rotary_pct)

    # Initialize Module
    xtrt_llm_gpt = xtrt_llm.models.GPTNeoXForCausalLM(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,
        max_position_embeddings=args.n_positions,
        rotary_dim=rotary_dim,
        dtype=kv_dtype,
        mapping=Mapping(world_size=args.world_size,
                        rank=rank,
                        tp_size=args.world_size),  # TP only
        apply_query_key_layer_scaling=builder_config.
        apply_query_key_layer_scaling,
        use_parallel_embedding=args.use_parallel_embedding,
        embedding_sharding_dim=args.embedding_sharding_dim)

    if args.use_weight_only_quant_matmul_plugin:
        xtrt_llm_gpt = weight_only_quantize(xtrt_llm_gpt)

    if args.use_weight_only_groupwise_quant_matmul_plugin:
        xtrt_llm_gpt = weight_only_groupwise_quantize(model=xtrt_llm_gpt,
                                                      quant_mode=QuantMode(0),
                                                      group_size=128,
                                                      zero=True)

    if args.model_dir is not None:
        assert hf_gpt is not None, 'Could not load weights from hf_gpt model as it is not loaded yet.'
        if args.world_size > 1:
            assert (
                args.n_embd % args.world_size == 0
            ), 'Embedding size/hidden size must be divisible by world size.'
            assert (
                args.n_head % args.world_size == 0
            ), 'Number of attention heads must be divisible by world size.'
        load_from_hf_gpt_neox(
            xtrt_llm_gpt, hf_gpt, args.dtype, rank, args.world_size,
            args.use_weight_only_groupwise_quant_matmul_plugin)

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin(
            dtype=args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
    if args.use_layernorm_plugin:
        network.plugin_config.set_layernorm_plugin(
            dtype=args.use_layernorm_plugin)
    # The two context-FMHA flags are mutually exclusive.
    assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
    if args.enable_context_fmha:
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.use_weight_only_quant_matmul_plugin:
        network.plugin_config.set_weight_only_quant_matmul_plugin(
            dtype=args.use_weight_only_quant_matmul_plugin)
    if args.use_weight_only_groupwise_quant_matmul_plugin:
        network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
            dtype=args.use_weight_only_groupwise_quant_matmul_plugin)
    if args.quant_mode.is_weight_only():
        builder_config.trt_builder_config.use_weight_only = args.weight_only_precision
    if args.world_size > 1:
        network.plugin_config.set_nccl_plugin(args.dtype)
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()

    with net_guard(network):
        # Prepare
        network.set_named_parameters(xtrt_llm_gpt.named_parameters())

        # Forward
        inputs = xtrt_llm_gpt.prepare_inputs(args.max_batch_size,
                                             args.max_input_len,
                                             args.max_output_len, True,
                                             args.max_beam_width)
        xtrt_llm_gpt(*inputs)

    #xtrt_llm.graph_rewriting.optimize(network)

    # Network -> Engine
    engine = builder.build_engine(network, builder_config, compiler="gr")
    if rank == 0:
        config_path = os.path.join(args.output_dir, 'config.json')
        builder.save_config(builder_config, config_path)
    return engine


def build(rank, args):
    """Build and serialize the engine(s) into ``args.output_dir``.

    With ``args.parallel_build`` set, only the engine for ``rank`` is built
    (one call per spawned process); otherwise every rank in
    ``range(args.world_size)`` is built serially in this process.

    @param rank: The rank this call is responsible for.
    @param args: The cmd line arguments.
    """
    #torch.cuda.set_device(rank % args.gpus_per_node)
    xtrt_llm.logger.set_level(args.log_level)
    # exist_ok avoids a check-then-create race when several spawned build
    # processes create the output directory at the same time.
    os.makedirs(args.output_dir, exist_ok=True)

    # when doing serializing build, all ranks share one engine
    apply_query_key_layer_scaling = False
    builder = Builder()

    # Stays None while the in-memory timing-cache reuse below is disabled.
    cache = None
    for cur_rank in range(args.world_size):
        # skip other ranks if parallel_build is enabled
        if args.parallel_build and cur_rank != rank:
            continue
        builder_config = builder.create_builder_config(
            name=MODEL_NAME,
            precision=args.dtype,
            timing_cache=args.timing_cache if cache is None else cache,
            tensor_parallel=args.world_size,  # TP only
            parallel_build=args.parallel_build,
            num_layers=args.n_layer,
            num_heads=args.n_head,
            inter_size=args.inter_size,
            hidden_size=args.n_embd,
            vocab_size=args.vocab_size,
            hidden_act=args.hidden_act,
            max_position_embeddings=args.n_positions,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            max_batch_size=args.max_batch_size,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            fusion_pattern_list=["remove_dup_mask"])

        engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
                                      cur_rank)
        engine = build_rank_engine(builder, builder_config, engine_name,
                                   cur_rank, args)
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        # if cur_rank == 0:
        #     # Use in-memory timing cache for multiple builder passes.
        #     if not args.parallel_build:
        #         cache = builder_config.trt_builder_config.get_timing_cache()

        serialize_engine(engine, os.path.join(args.output_dir, engine_name))

    # if rank == 0:
    #     ok = builder.save_timing_cache(
    #         builder_config, os.path.join(args.output_dir, "model.cache"))
    #     assert ok, "Failed to save timing cache."

if __name__ == '__main__':
    # Script entry point: parse the CLI, build every engine, report the time.
    args = parse_arguments()
    started_at = time.time()

    # Parallel build needs to be requested, to involve more than one rank,
    # and to have at least one visible device per rank. The device count is
    # only queried when the first two conditions hold.
    build_in_parallel = (args.parallel_build and args.world_size > 1
                         and torch.cuda.device_count() >= args.world_size)

    if not build_in_parallel:
        # Fall back to building every rank in this process.
        args.parallel_build = False
        logger.info('Serially build TensorRT engines.')
        build(0, args)
    else:
        logger.warning(
            f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
        )
        # One process per rank; each receives its rank as the first argument.
        mp.spawn(build, nprocs=args.world_size, args=(args, ))

    finished_at = time.time()
    elapsed = time.strftime('%H:%M:%S', time.gmtime(finished_at - started_at))
    logger.info(
        f'Total time of building all {args.world_size} engines: {elapsed}')