# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import os import time import onnx import torch.multiprocessing as mp import tvm as trt from onnx import TensorProto, helper from transformers import AutoConfig, AutoModelForCausalLM import xtrt_llm from xtrt_llm._utils import str_dtype_to_xtrt from xtrt_llm.builder import Builder from xtrt_llm.layers.attention import PositionEmbeddingType from xtrt_llm.logger import logger from xtrt_llm.mapping import Mapping from xtrt_llm.models import BaichuanForCausalLM, weight_only_quantize from xtrt_llm.network import net_guard from xtrt_llm.plugin.plugin import ContextFMHAType from xtrt_llm.quantization import QuantMode from weight import load_from_hf_baichuan # isort:skip # 2 routines: get_engine_name, serialize_engine # are direct copy from gpt example, TODO: put in utils? 
def _format_elapsed(seconds):
    """Format a duration in seconds as ``HH:MM:SS`` for log messages."""
    return time.strftime('%H:%M:%S', time.gmtime(seconds))


def _is_7b(model_version):
    """Return True for the 7B Baichuan variants (which use RoPE, not ALiBi)."""
    return model_version in ('v1_7b', 'v2_7b')


def trt_dtype_to_onnx(dtype):
    """Map a network tensor dtype to the matching ONNX ``TensorProto`` enum.

    Only float16, float32 and int32 are handled (used by ``to_onnx`` for
    ``--visualize``). NOTE(review): bfloat16 is an accepted ``--dtype`` but is
    not mapped here, so ``--visualize`` with bf16 tensors will raise.

    Raises:
        TypeError: if ``dtype`` is not one of the handled types.
    """
    if dtype == trt.float16:
        return TensorProto.DataType.FLOAT16
    if dtype == trt.float32:
        return TensorProto.DataType.FLOAT
    if dtype == trt.int32:
        return TensorProto.DataType.INT32
    raise TypeError("%s is not supported" % dtype)


def _value_info(tensor):
    """Build an ONNX value-info entry from a tensor's name, dtype and shape."""
    return helper.make_tensor_value_info(tensor.name,
                                         trt_dtype_to_onnx(tensor.dtype),
                                         list(tensor.shape))


def to_onnx(network, path):
    """Dump the structure of ``network`` as an ONNX graph to ``path``.

    The result is only for visualization: each network layer becomes an ONNX
    node in the ``com.nvidia`` domain whose op type is ``str(layer.type)``; no
    weights/initializers are written.
    """
    inputs = [
        _value_info(network.get_input(i)) for i in range(network.num_inputs)
    ]
    outputs = [
        _value_info(network.get_output(i))
        for i in range(network.num_outputs)
    ]

    nodes = []
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        # get_input may return None for an unset input; such slots are skipped.
        candidates = (layer.get_input(j) for j in range(layer.num_inputs))
        layer_inputs = [ipt.name for ipt in candidates if ipt is not None]
        layer_outputs = [
            layer.get_output(j).name for j in range(layer.num_outputs)
        ]
        nodes.append(
            helper.make_node(str(layer.type),
                             name=layer.name,
                             inputs=layer_inputs,
                             outputs=layer_outputs,
                             domain="com.nvidia"))

    graph = helper.make_graph(nodes,
                              'attention',
                              inputs,
                              outputs,
                              initializer=None)
    onnx_model = helper.make_model(graph, producer_name='NVIDIA')
    onnx.save(onnx_model, path)


def get_engine_name(model, dtype, tp_size, rank):
    """Return the per-rank engine file name, e.g. ``baichuan_float16_tp2_rank0.engine``."""
    return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)


def serialize_engine(engine, path):
    """Write ``engine`` to ``path`` via ``engine.serialize`` and log the time taken."""
    logger.info(f'Serializing engine to {path}...')
    tik = time.time()
    engine.serialize(path)
    tok = time.time()
    t = _format_elapsed(tok - tik)
    logger.info(f'Engine serialized. Total time: {t}')


# Architecture presets applied when no HF config is available. The CLI
# defaults already describe v1_13b, so it has no entry here.
_ARCH_PRESETS = {
    'v1_7b':
    dict(inter_size=11008,
         n_embd=4096,
         n_head=32,
         n_layer=32,
         n_positions=4096,
         vocab_size=64000,
         hidden_act='silu'),
    'v2_7b':
    dict(inter_size=11008,
         n_embd=4096,
         n_head=32,
         n_layer=32,
         n_positions=4096,
         vocab_size=125696,
         hidden_act='silu'),
    'v2_13b':
    dict(inter_size=13696,
         n_embd=5120,
         n_head=40,
         n_layer=40,
         n_positions=4096,
         vocab_size=125696,
         hidden_act='silu'),
}


def parse_arguments():
    """Parse and post-process the command-line arguments.

    Besides plain parsing this:
      * derives ``args.quant_mode`` from the weight-only flags;
      * forces the attention plugin, input-padding removal and paged KV cache
        on when ``--use_inflight_batching`` is set;
      * validates flag combinations (exits via ``parser.error`` on misuse);
      * overrides the architecture fields from the HF config in
        ``--model_dir`` (or from ``_ARCH_PRESETS`` when no model dir is set).

    Returns:
        argparse.Namespace: the processed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='world size, only support tensor parallelism now')
    parser.add_argument('--model_dir',
                        type=str,
                        default='baichuan-inc/Baichuan-13B-Chat')
    parser.add_argument('--model_version',
                        type=str,
                        default='v1_13b',
                        choices=['v1_7b', 'v1_13b', 'v2_7b', 'v2_13b'])
    parser.add_argument('--dtype',
                        type=str,
                        default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    # NOTE(review): default=True combined with store_true means this option
    # is always True; passing the flag cannot turn it off.
    parser.add_argument(
        '--opt_memory_use',
        default=True,
        action="store_true",
        help='Whether to use Host memory optimization for building engine')
    parser.add_argument(
        '--timing_cache',
        type=str,
        default='model.cache',
        help=
        'The path to read timing cache from, will be ignored if the file does not exist'
    )
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--pp_size', type=int, default=1)
    parser.add_argument('--vocab_size', type=int, default=64000)
    parser.add_argument('--n_layer', type=int, default=40)
    parser.add_argument('--n_positions', type=int, default=4096)
    parser.add_argument('--n_embd', type=int, default=5120)
    parser.add_argument('--n_head', type=int, default=40)
    parser.add_argument('--inter_size', type=int, default=13696)
    parser.add_argument('--hidden_act', type=str, default='silu')
    parser.add_argument('--max_batch_size', type=int, default=1)
    parser.add_argument('--max_input_len', type=int, default=1024)
    parser.add_argument('--max_output_len', type=int, default=1024)
    parser.add_argument('--max_beam_width', type=int, default=1)
    # NOTE(review): the default is the bool True rather than one of `choices`,
    # so set_gpt_attention_plugin receives dtype=True unless a dtype is passed
    # explicitly -- confirm this is intended for the xtrt backend.
    parser.add_argument('--use_gpt_attention_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=True,
                        choices=['float16', 'bfloat16', 'float32'])
    parser.add_argument('--use_gemm_plugin',
                        nargs='?',
                        const='float16',
                        type=str,
                        default=False,
                        choices=['float16', 'bfloat16', 'float32'])
    parser.add_argument('--enable_context_fmha',
                        default=False,
                        action='store_true')
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        default=False,
                        action='store_true')
    parser.add_argument('--parallel_build', default=False, action='store_true')
    parser.add_argument('--visualize', default=False, action='store_true')
    parser.add_argument('--enable_debug_output',
                        default=False,
                        action='store_true')
    parser.add_argument('--gpus_per_node', type=int, default=8)
    parser.add_argument(
        '--output_dir',
        type=str,
        default='baichuan_outputs',
        help=
        'The path to save the serialized engine files, timing cache file and model configs'
    )
    parser.add_argument('--remove_input_padding',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also use --use_weight_only for that argument to have an impact.'
    )
    parser.add_argument(
        '--use_inflight_batching',
        action="store_true",
        default=False,
        help="Activates inflight batching mode of gptAttentionPlugin.")
    parser.add_argument(
        '--paged_kv_cache',
        action="store_true",
        default=False,
        help=
        'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
    )
    parser.add_argument('--tokens_per_block',
                        type=int,
                        default=64,
                        help='Number of tokens per block in paged KV cache')
    parser.add_argument(
        '--max_num_tokens',
        type=int,
        default=None,
        help='Define the max number of tokens supported by the engine')
    parser.add_argument('--gather_all_token_logits',
                        action='store_true',
                        default=False)
    args = parser.parse_args()

    if args.use_weight_only:
        args.quant_mode = QuantMode.use_weight_only(
            args.weight_only_precision == 'int4')
    else:
        args.quant_mode = QuantMode(0)

    if args.use_inflight_batching:
        if not args.use_gpt_attention_plugin:
            args.use_gpt_attention_plugin = 'float16'
            logger.info(
                f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
            )
        if not args.remove_input_padding:
            args.remove_input_padding = True
            logger.info(
                "Using remove input padding for inflight batching mode.")
        if not args.paged_kv_cache:
            args.paged_kv_cache = True
            logger.info("Using paged KV cache for inflight batching mode.")

    # Validate flag combinations up front (before the slow HF config load).
    # parser.error is used instead of assert so the checks survive `python -O`.
    if args.max_num_tokens is not None and not args.enable_context_fmha:
        parser.error('--max_num_tokens requires --enable_context_fmha')
    if args.dtype == 'bfloat16' and not args.use_gemm_plugin:
        parser.error("Please use gemm plugin when dtype is bfloat16")

    if args.model_dir is not None:
        hf_config = AutoConfig.from_pretrained(args.model_dir,
                                               trust_remote_code=True)
        # override the inter_size for Baichuan
        args.inter_size = hf_config.intermediate_size
        args.n_embd = hf_config.hidden_size
        args.n_head = hf_config.num_attention_heads
        args.n_layer = hf_config.num_hidden_layers
        # The 7B configs expose max_position_embeddings; the 13B configs read
        # model_max_length instead.
        if _is_7b(args.model_version):
            args.n_positions = hf_config.max_position_embeddings
        else:
            args.n_positions = hf_config.model_max_length
        args.vocab_size = hf_config.vocab_size
        args.hidden_act = hf_config.hidden_act
    else:
        # NOTE(review): --model_dir has a non-None string default, so this
        # branch is not reachable from the command line as written.
        # Default values are based on v1_13b; change them per model_version.
        for key, value in _ARCH_PRESETS.get(args.model_version, {}).items():
            setattr(args, key, value)

    return args


def build_rank_engine(builder: Builder,
                      builder_config: xtrt_llm.builder.BuilderConfig,
                      engine_name, rank, args):
    '''
       @brief: Build the engine on the given rank.
       @param builder: Builder used to create the network and build the engine.
       @param builder_config: Config for this rank's build; rank 0 also saves
           it as config.json in args.output_dir.
       @param engine_name: Name assigned to the underlying network.
       @param rank: The rank to build the engine.
       @param args: The cmd line arguments.
       @return: The built engine, or ``(engine, network)`` when
           ``args.opt_memory_use`` is set.
    '''
    kv_dtype = str_dtype_to_xtrt(args.dtype)
    # 7B variants use GPT-NeoX style RoPE; the others use ALiBi.
    if _is_7b(args.model_version):
        position_embedding_type = PositionEmbeddingType.rope_gpt_neox
    else:
        position_embedding_type = PositionEmbeddingType.alibi

    # Initialize Module
    xtrt_llm_baichuan = BaichuanForCausalLM(
        num_layers=args.n_layer,
        num_heads=args.n_head,
        hidden_size=args.n_embd,
        vocab_size=args.vocab_size,
        hidden_act=args.hidden_act,
        max_position_embeddings=args.n_positions,
        position_embedding_type=position_embedding_type,
        dtype=kv_dtype,
        mlp_hidden_size=args.inter_size,
        mapping=Mapping(world_size=args.world_size,
                        rank=rank,
                        tp_size=args.world_size),
        gather_all_token_logits=args.gather_all_token_logits)

    # Module-level weight-only quantization is intentionally disabled; when
    # --use_weight_only is set, the weight-only matmul plugin and the
    # builder_config flag are configured further below instead.
    module_weight_only_quant = False
    if module_weight_only_quant and args.use_weight_only:
        if args.weight_only_precision == 'int8':
            xtrt_llm_baichuan = weight_only_quantize(
                xtrt_llm_baichuan, QuantMode.use_weight_only())
        elif args.weight_only_precision == 'int4':
            xtrt_llm_baichuan = weight_only_quantize(
                xtrt_llm_baichuan,
                QuantMode.use_weight_only(use_int4_weights=True))

    if args.model_dir is not None:
        logger.info(
            f'Loading HF Baichuan {args.model_version} ... from {args.model_dir}'
        )
        tik = time.time()
        hf_baichuan = AutoModelForCausalLM.from_pretrained(
            args.model_dir,
            device_map={
                "model": "cpu",
                "lm_head": "cpu"
            },  # Load to CPU memory
            torch_dtype="auto",
            trust_remote_code=True)
        tok = time.time()
        t = _format_elapsed(tok - tik)
        logger.info(f'HF Baichuan {args.model_version} loaded. Total time: {t}')
        load_from_hf_baichuan(xtrt_llm_baichuan,
                              hf_baichuan,
                              args.model_version,
                              rank,
                              args.world_size,
                              dtype=args.dtype)
        # Drop the HF model reference once its weights have been copied over.
        del hf_baichuan

    # Module -> Network
    network = builder.create_network()
    network.trt_network.name = engine_name
    if args.use_gpt_attention_plugin:
        network.plugin_config.set_gpt_attention_plugin(
            dtype=args.use_gpt_attention_plugin)
    if args.use_gemm_plugin:
        network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
    # The two context-FMHA modes are mutually exclusive.
    assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
    if args.enable_context_fmha:
        network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
    if args.enable_context_fmha_fp32_acc:
        network.plugin_config.set_context_fmha(
            ContextFMHAType.enabled_with_fp32_acc)
    if args.use_weight_only:
        network.plugin_config.set_weight_only_quant_matmul_plugin(
            dtype='float16')
        builder_config.trt_builder_config.use_weight_only = args.weight_only_precision
    if args.world_size > 1:
        network.plugin_config.set_nccl_plugin(args.dtype)
    if args.remove_input_padding:
        network.plugin_config.enable_remove_input_padding()
    if args.paged_kv_cache:
        network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)

    with net_guard(network):
        # Prepare
        network.set_named_parameters(xtrt_llm_baichuan.named_parameters())

        # Forward
        inputs = xtrt_llm_baichuan.prepare_inputs(args.max_batch_size,
                                                  args.max_input_len,
                                                  args.max_output_len, True,
                                                  args.max_beam_width,
                                                  args.max_num_tokens)
        xtrt_llm_baichuan(*inputs)
        if args.enable_debug_output:
            # mark intermediate nodes' outputs
            for k, v in xtrt_llm_baichuan.named_network_outputs():
                v = v.trt_tensor
                v.name = k
                network.trt_network.mark_output(v)
                v.dtype = kv_dtype
        if args.visualize:
            model_path = os.path.join(args.output_dir, 'test.onnx')
            to_onnx(network.trt_network, model_path)

    # Network -> Engine
    engine = builder.build_engine(network, builder_config, compiler="gr")
    if rank == 0:
        config_path = os.path.join(args.output_dir, 'config.json')
        builder.save_config(builder_config, config_path)
    if args.opt_memory_use:
        # Presumably the caller keeps `network` referenced until the engine
        # is serialized -- TODO confirm what the xtrt backend requires.
        return engine, network
    return engine


def build(rank, args):
    """Build and serialize the engine(s) handled by process ``rank``.

    With ``args.parallel_build`` each spawned process builds only its own
    rank; otherwise one process builds every rank in turn.
    """
    # torch.cuda.set_device(rank % args.gpus_per_node)
    xtrt_llm.logger.set_level(args.log_level)
    # exist_ok avoids a FileExistsError race when several spawned ranks
    # create the directory concurrently during a parallel build.
    os.makedirs(args.output_dir, exist_ok=True)

    # when doing serializing build, all ranks share one engine
    builder = Builder()

    model_name = 'baichuan'
    for cur_rank in range(args.world_size):
        # skip other ranks if parallel_build is enabled
        if args.parallel_build and cur_rank != rank:
            continue
        builder_config = builder.create_builder_config(
            name=model_name,
            precision=args.dtype,
            timing_cache=args.timing_cache,
            tensor_parallel=args.world_size,  # TP only
            parallel_build=args.parallel_build,
            pipeline_parallel=args.pp_size,
            num_layers=args.n_layer,
            num_heads=args.n_head,
            hidden_size=args.n_embd,
            inter_size=args.inter_size,
            vocab_size=args.vocab_size,
            hidden_act=args.hidden_act,
            max_position_embeddings=args.n_positions,
            max_batch_size=args.max_batch_size,
            max_input_len=args.max_input_len,
            max_output_len=args.max_output_len,
            max_num_tokens=args.max_num_tokens,
            int8=args.quant_mode.has_act_and_weight_quant(),
            quant_mode=args.quant_mode,
            fusion_pattern_list=["remove_dup_mask"],
            gather_all_token_logits=args.gather_all_token_logits,
        )
        # NOTE(review): the guard is constructed and printed but not used as
        # a context manager -- confirm construction alone has the intended
        # effect.
        guard = xtrt_llm.fusion_patterns.FuseonPatternGuard()
        print(guard)
        engine_name = get_engine_name(model_name, args.dtype, args.world_size,
                                      cur_rank)
        if args.opt_memory_use:
            # `network` stays bound until the engine below is serialized.
            engine, network = build_rank_engine(builder, builder_config,
                                                engine_name, cur_rank, args)
        else:
            engine = build_rank_engine(builder, builder_config, engine_name,
                                       cur_rank, args)
        assert engine is not None, f'Failed to build engine for rank {cur_rank}'

        serialize_engine(engine, os.path.join(args.output_dir, engine_name))


if __name__ == '__main__':
    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()
    if args.parallel_build and args.world_size > 1:
        logger.warning(
            f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
        )
        mp.spawn(build, nprocs=args.world_size, args=(args, ))
    else:
        args.parallel_build = False
        logger.info('Serially build TensorRT engines.')
        build(0, args)

    tok = time.time()
    t = _format_elapsed(tok - tik)
    logger.info(f'Total time of building all {args.world_size} engines: {t}')