# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run text generation with a built xtrt_llm engine.

Reads ``config.json`` from ``--engine_dir``, loads a tokenizer and a
``ModelRunner``, then either:

* tokenizes ``--input_text`` / ``--input_file``, generates, and prints the
  decoded results (or saves token ids / logits to the requested files), or
* with ``--performance_test_scale``, generates on dummy all-zero prompts and
  logs the latency of each run.
"""
import argparse
import csv
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
from utils import (DEFAULT_HF_MODEL_DIRS, DEFAULT_PROMPT_TEMPLATES,
                   load_tokenizer, read_model_name_from_config,
                   throttle_generator)

import xtrt_llm
from xtrt_llm.logger import logger
from xtrt_llm.runtime import ModelRunner, read_config


def parse_arguments(args=None):
    """Build the command-line parser and parse ``args``.

    Args:
        args: Optional list of argument strings; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_output_len', type=int, required=True)
    parser.add_argument(
        '--max_kv_cache_length',
        type=int,
        default=None,
        help='The max kv cache length. '
        'If the final sequence length exceeds the kv cache length, '
        'we will enable cyclic kv cache. '
        'If it is set to None, we will use the max sequence length.')
    parser.add_argument('--log_level', type=str, default='error')
    parser.add_argument('--engine_dir', type=str, default='engine_outputs')
    parser.add_argument(
        '--input_text',
        type=str,
        nargs='+',
        default=["Born in north-east France, Soyer trained as a"])
    parser.add_argument(
        '--no_prompt_template',
        dest='use_prompt_template',
        default=True,
        action='store_false',
        help=
        "Whether or not to use default prompt template to wrap the input text.")
    parser.add_argument(
        '--input_file',
        type=str,
        help='CSV or Numpy file containing tokenized input, or a .txt file '
        'containing raw text. Alternative to text input.',
        default=None)
    parser.add_argument('--max_input_length', type=int, default=923)
    parser.add_argument('--output_csv',
                        type=str,
                        help='CSV file where the tokenized output is stored.',
                        default=None)
    parser.add_argument('--output_npy',
                        type=str,
                        help='Numpy file where the tokenized output is stored.',
                        default=None)
    parser.add_argument(
        '--output_logits_npy',
        type=str,
        help=
        'Numpy file where the generation logits are stored. Use only when num_beams==1',
        default=None)
    parser.add_argument('--tokenizer_dir',
                        help="HF tokenizer config path",
                        default='gpt2')
    parser.add_argument('--vocab_file',
                        help="Used for sentencepiece tokenizers")
    parser.add_argument('--num_beams',
                        type=int,
                        help="Use beam search if num_beams >1",
                        default=1)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--top_k', type=int, default=1)
    parser.add_argument('--top_p', type=float, default=0.0)
    parser.add_argument('--length_penalty', type=float, default=1.0)
    parser.add_argument('--repetition_penalty', type=float, default=1.0)
    parser.add_argument('--debug_mode',
                        default=False,
                        action='store_true',
                        help="Whether or not to turn on the debug mode")
    parser.add_argument('--no_add_special_tokens',
                        dest='add_special_tokens',
                        default=True,
                        action='store_false',
                        help="Whether or not to add special tokens")
    parser.add_argument('--streaming', default=False, action='store_true')
    parser.add_argument('--streaming_interval',
                        type=int,
                        help="How often to return tokens when streaming.",
                        default=5)
    parser.add_argument(
        '--prompt_table_path',
        type=str,
        help="Path to .npy file, exported by nemo_prompt_convert.py")
    parser.add_argument(
        '--prompt_tasks',
        help="Comma-separated list of tasks for prompt tuning, e.g., 0,3,1,0")
    parser.add_argument('--lora_dir',
                        type=str,
                        default=None,
                        help="The directory of LoRA weights")
    parser.add_argument(
        '--lora_task_uids',
        type=str,
        default=None,
        nargs="+",
        help="The list of LoRA task uids; use -1 to disable the LoRA module")
    parser.add_argument(
        '--performance_test_scale',
        type=str,
        help=
        "Scale for performance test. e.g., 8x1024x64 (batch_size, input_text_length, max_output_length). "
        "Several scales can be chained with 'E', e.g., 8x1024x64E16x512x32")
    return parser.parse_args(args=args)


def parse_input(tokenizer,
                input_text=None,
                prompt_template=None,
                input_file=None,
                add_special_tokens=True,
                max_input_length=923,
                pad_id=None):
    """Tokenize prompts into a list of ``[1, seq_len]`` int32 tensors.

    ``input_file`` takes precedence; ``input_text`` is used only when no file
    is given.

    Args:
        tokenizer: Tokenizer providing ``encode`` (and ``pad_token_id`` when
            ``pad_id`` is None).
        input_text: Iterable of prompt strings.
        prompt_template: Optional format string with an ``{input_text}``
            placeholder, applied to each entry of ``input_text`` only.
        input_file: ``.csv`` (one row of token ids per prompt), ``.npy``
            (rows of token ids; entries equal to ``pad_id`` are dropped) or
            ``.txt`` (raw text encoded as a single prompt).
        add_special_tokens: Forwarded to ``tokenizer.encode``.
        max_input_length: Text is truncated by the tokenizer to this length;
            rows read from ``.csv``/``.npy`` keep their last
            ``max_input_length`` ids.
        pad_id: Padding id stripped from ``.npy`` rows; defaults to
            ``tokenizer.pad_token_id``.

    Returns:
        list of ``torch.int32`` tensors, each of shape ``[1, seq_len]``.

    Raises:
        ValueError: if neither ``input_text`` nor ``input_file`` is given.
        SystemExit: with exit status 1 for an unsupported file extension.
    """
    if pad_id is None:
        pad_id = tokenizer.pad_token_id

    batch_input_ids = []
    if input_file is None:
        if input_text is None:
            raise ValueError(
                'Either input_text or input_file must be provided.')
        for curr_text in input_text:
            if prompt_template is not None:
                curr_text = prompt_template.format(input_text=curr_text)
            input_ids = tokenizer.encode(curr_text,
                                         add_special_tokens=add_special_tokens,
                                         truncation=True,
                                         max_length=max_input_length)
            batch_input_ids.append(input_ids)
    elif input_file.endswith('.csv'):
        # newline='' is required by the csv module for correct parsing.
        with open(input_file, 'r', newline='') as csv_file:
            for line in csv.reader(csv_file, delimiter=','):
                if not line:
                    # csv.reader yields [] for blank lines (e.g. a trailing
                    # empty line); they are not prompts.
                    continue
                input_ids = np.array(line, dtype='int32')
                batch_input_ids.append(input_ids[-max_input_length:])
    elif input_file.endswith('.npy'):
        inputs = np.load(input_file)
        for row in inputs:
            input_ids = row[row != pad_id]
            batch_input_ids.append(input_ids[-max_input_length:])
    elif input_file.endswith('.txt'):
        with open(input_file, 'r', encoding='utf-8',
                  errors='replace') as txt_file:
            file_text = txt_file.read()
        input_ids = tokenizer.encode(file_text,
                                     add_special_tokens=add_special_tokens,
                                     truncation=True,
                                     max_length=max_input_length)
        batch_input_ids.append(input_ids)
    else:
        # A message argument makes SystemExit print to stderr and exit with
        # status 1; a bare SystemExit would report success (status 0).
        raise SystemExit('Input file format not supported.')

    return [
        torch.tensor(ids, dtype=torch.int32).unsqueeze(0)
        for ids in batch_input_ids
    ]


def _save_logits(context_logits, generation_logits, input_lengths, num_beams,
                 output_logits_npy):
    """Save each sequence's last-prompt-token logits followed by step logits.

    NOTE(review): layouts below are inferred from the reshapes in this
    function, not from the runtime; confirm against ModelRunner:
      * ``torch.cat(context_logits)`` packs all prompt positions of the batch
        back-to-back along dim 0, vocab on the last dim;
      * each entry of ``generation_logits`` holds one decoding step with the
        batch on dim 0 and vocab on the last dim.
    The saved float32 array is reshaped to ``[-1, num_beams, positions, vocab]``.
    """
    context_logits = torch.cat(context_logits, dim=0)
    generation_logits = torch.cat(
        [logit.unsqueeze(1) for logit in generation_logits], dim=1)

    lengths = torch.as_tensor(input_lengths, dtype=torch.int64)
    batch_size = lengths.size(0)
    vocab_size_padded = context_logits.shape[-1]

    # Index of each sequence's last prompt token in the packed context.
    # An integer cumsum avoids float rounding, and index_select needs the
    # index on the same device as the logits.
    last_token_ids = (torch.cumsum(lengths, dim=0) - 1).int().to(
        context_logits.device)

    context_logits = context_logits.reshape([1, -1, vocab_size_padded])
    context_logits = torch.index_select(context_logits, 1,
                                        last_token_ids).view(
                                            batch_size, 1, vocab_size_padded)
    logits = torch.cat([context_logits, generation_logits], dim=1)
    logits = logits.reshape(-1, num_beams, logits.shape[1], logits.shape[2])

    output_file = Path(output_logits_npy)
    output_file.parent.mkdir(exist_ok=True, parents=True)
    np.save(output_file, np.array(logits.cpu().contiguous(), dtype='float32'))


def print_output(tokenizer,
                 output_ids,
                 input_lengths,
                 sequence_lengths,
                 output_csv=None,
                 output_npy=None,
                 context_logits=None,
                 generation_logits=None,
                 output_logits_npy=None):
    """Print and/or save generated token ids, and optionally logits.

    When neither ``output_csv`` nor ``output_npy`` is set, each prompt and each
    beam's continuation are decoded and printed. Otherwise the full
    ``output_ids`` (prompt + generation), flattened to
    ``[batch * beams, seq]``, are written to the requested file(s). Logits
    are saved only when ``generation_logits`` and ``output_logits_npy`` are
    given and there is a single beam.

    Args:
        tokenizer: Tokenizer providing ``decode``.
        output_ids: Tensor of shape ``[batch, beams, seq]``.
        input_lengths: Prompt length per batch entry.
        sequence_lengths: Indexed ``[batch][beam]``; end offset of each
            beam's tokens within ``output_ids``.
        output_csv / output_npy: Optional output paths for token ids.
        context_logits / generation_logits: Logits returned by the runner.
        output_logits_npy: Optional output path for logits.
    """
    batch_size, num_beams, _ = output_ids.size()
    if output_csv is None and output_npy is None:
        for batch_idx in range(batch_size):
            output_begin = input_lengths[batch_idx]
            inputs = output_ids[batch_idx][0][:output_begin].tolist()
            input_text = tokenizer.decode(inputs)
            print(f'Input idx: [Text {batch_idx}]')
            print(f'Input: \"{input_text}\"')
            for beam in range(num_beams):
                output_end = sequence_lengths[batch_idx][beam]
                outputs = output_ids[batch_idx][beam][
                    output_begin:output_end].tolist()
                output_text = tokenizer.decode(outputs)
                print(f'Output idx: [Text {batch_idx} Beam {beam}]')
                print(f'Output: \"{output_text}\"')

    # Flatten [batch, beams, seq] -> [batch * beams, seq] for file output.
    output_ids = output_ids.reshape((-1, output_ids.size(2)))

    if output_csv is not None:
        output_file = Path(output_csv)
        output_file.parent.mkdir(exist_ok=True, parents=True)
        # newline='' lets the csv writer control line endings.
        with open(output_file, 'w', newline='') as csv_file:
            csv.writer(csv_file, delimiter=',').writerows(output_ids.tolist())

    if output_npy is not None:
        output_file = Path(output_npy)
        output_file.parent.mkdir(exist_ok=True, parents=True)
        np.save(output_file,
                np.array(output_ids.cpu().contiguous(), dtype='int32'))

    if (generation_logits is not None and output_logits_npy is not None
            and num_beams == 1):
        _save_logits(context_logits, generation_logits, input_lengths,
                     num_beams, output_logits_npy)


def _generate(runner, args, batch_input_ids, max_new_tokens, end_id, pad_id,
              stop_words_list, bad_words_list):
    """Call ``runner.generate`` with the sampling options taken from ``args``.

    Shared by the performance test and the regular path so both use identical
    settings; only the prompts and ``max_new_tokens`` differ.
    """
    return runner.generate(batch_input_ids,
                           max_new_tokens=max_new_tokens,
                           max_kv_cache_length=args.max_kv_cache_length,
                           end_id=end_id,
                           pad_id=pad_id,
                           temperature=args.temperature,
                           top_k=args.top_k,
                           top_p=args.top_p,
                           num_beams=args.num_beams,
                           length_penalty=args.length_penalty,
                           repetition_penalty=args.repetition_penalty,
                           stop_words_list=stop_words_list,
                           bad_words_list=bad_words_list,
                           lora_uids=args.lora_task_uids,
                           prompt_table_path=args.prompt_table_path,
                           prompt_tasks=args.prompt_tasks,
                           streaming=args.streaming,
                           output_sequence_lengths=True,
                           return_dict=True)


def _run_performance_test(runner, args, end_id, pad_id, stop_words_list,
                          bad_words_list):
    """Time generation for each scale in ``args.performance_test_scale``.

    The option holds ``BSxINPUT_LENxMAX_OUTPUT_LEN`` triples separated by
    ``E`` (e.g. ``8x1024x64E16x512x32``). Each run feeds ``BS`` all-zero
    prompts of length ``INPUT_LEN``. The timed span also covers building
    those dummy inputs. Results go to the info-level log; with the default
    ``--log_level error`` they are presumably hidden (TODO confirm).
    """
    for scale in args.performance_test_scale.split("E"):
        xtrt_llm.logger.info(f"Running performance test with scale {scale}")
        start = time.perf_counter()
        batch_size, seq_len, max_output_len = [
            int(x) for x in scale.split("x")
        ]
        batch_input_ids = [
            torch.from_numpy(np.zeros((seq_len, )).astype("int32"))
            for _ in range(batch_size)
        ]
        with torch.no_grad():
            outputs = _generate(runner, args, batch_input_ids, max_output_len,
                                end_id, pad_id, stop_words_list,
                                bad_words_list)
            if args.streaming:
                # NOTE(review): when streaming, generate() presumably returns
                # a lazy generator; drain it so the timed region decodes.
                for _ in outputs:
                    pass
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - start) * 1000
        xtrt_llm.logger.info(f"Total latency: {elapsed_ms:.3f} ms")


def main(args):
    """Run generation (or the performance test) for parsed ``args``."""
    runtime_rank = xtrt_llm.mpi_rank()
    logger.set_level(args.log_level)

    config_path = Path(args.engine_dir) / "config.json"
    model_name = read_model_name_from_config(config_path)
    if args.tokenizer_dir is None:
        # parse_arguments() defaults --tokenizer_dir to 'gpt2', so this
        # fallback only fires when main() receives args built elsewhere.
        args.tokenizer_dir = DEFAULT_HF_MODEL_DIRS[model_name]

    # Export per-rank XCCL settings when the engine is sharded (tp * pp > 1).
    _, other_cfg = read_config(config_path)
    tp_size, pp_size = other_cfg["tp_size"], other_cfg["pp_size"]
    world_size = tp_size * pp_size
    if world_size > 1:
        os.environ["XCCL_GROUP_ID"] = str(runtime_rank // world_size)
        os.environ["XCCL_NRANKS"] = str(world_size)
        os.environ["XCCL_CUR_RANK"] = str(runtime_rank % world_size)
        os.environ["XCCL_DEVICE_ID"] = str(runtime_rank)
        os.environ["MP_RUN"] = str(1)

    tokenizer, pad_id, end_id = load_tokenizer(
        tokenizer_dir=args.tokenizer_dir,
        vocab_file=args.vocab_file,
        model_name=model_name,
    )

    runner = ModelRunner.from_dir(engine_dir=args.engine_dir,
                                  lora_dir=args.lora_dir,
                                  rank=runtime_rank,
                                  debug_mode=args.debug_mode)

    # Stop generation when the model emits "<|endoftext|>".
    # NOTE(review): this is a single-entry list regardless of batch size;
    # confirm the runtime applies it to every sequence in the batch.
    stop_words_list = xtrt_llm.runtime.to_word_list_format(
        [["<|endoftext|>"]], tokenizer)
    stop_words_list = torch.tensor(stop_words_list,
                                   dtype=torch.int32).to("cuda").contiguous()

    # To forbid words, build bad_words_list the same way, e.g.
    #   xtrt_llm.runtime.to_word_list_format(
    #       [[" chef"], [" eventually, chef before"]], tokenizer)
    # and convert it to an int32 CUDA tensor as above.
    bad_words_list = None

    if args.use_prompt_template and model_name in DEFAULT_PROMPT_TEMPLATES:
        prompt_template = DEFAULT_PROMPT_TEMPLATES[model_name]
    else:
        prompt_template = None

    if args.performance_test_scale is not None:
        _run_performance_test(runner, args, end_id, pad_id, stop_words_list,
                              bad_words_list)
        sys.exit(0)

    batch_input_ids = parse_input(tokenizer=tokenizer,
                                  input_text=args.input_text,
                                  prompt_template=prompt_template,
                                  input_file=args.input_file,
                                  add_special_tokens=args.add_special_tokens,
                                  max_input_length=args.max_input_length,
                                  pad_id=pad_id)
    input_lengths = [x.size(1) for x in batch_input_ids]

    with torch.no_grad():
        outputs = _generate(runner, args, batch_input_ids,
                            args.max_output_len, end_id, pad_id,
                            stop_words_list, bad_words_list)
        torch.cuda.synchronize()

    if args.streaming:
        # Every rank iterates the stream so that, if generate() yields lazily,
        # all ranks take part in decoding; only rank 0 prints/saves.
        for curr_outputs in throttle_generator(outputs,
                                               args.streaming_interval):
            if runtime_rank == 0:
                print_output(tokenizer,
                             curr_outputs['output_ids'],
                             input_lengths,
                             curr_outputs['sequence_lengths'],
                             output_csv=args.output_csv,
                             output_npy=args.output_npy)
    elif runtime_rank == 0:
        context_logits = None
        generation_logits = None
        if runner.session.gather_all_token_logits:
            context_logits = outputs['context_logits']
            generation_logits = outputs['generation_logits']
        print_output(tokenizer,
                     outputs['output_ids'],
                     input_lengths,
                     outputs['sequence_lengths'],
                     output_csv=args.output_csv,
                     output_npy=args.output_npy,
                     context_logits=context_logits,
                     generation_logits=generation_logits,
                     output_logits_npy=args.output_logits_npy)


if __name__ == '__main__':
    args = parse_arguments()
    main(args)