This commit is contained in:
2025-08-06 15:49:14 +08:00
parent e80b916c52
commit bf00e72fb2
111 changed files with 21880 additions and 1 deletions

5
examples/gptj/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
__pycache__/
gptj_model/
*.log
*.txt
*.json

77
examples/gptj/README.md Normal file
View File

@@ -0,0 +1,77 @@
# GPT-J
This document explains how to build the [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6b) model using XTRT-LLM and run on a single XPU.
## Overview
The XTRT-LLM GPT-J example
code is located in [`examples/gptj`](./). There are several main files in that folder:
* [`build.py`](./build.py) to build the [XTRT] engine(s) needed to run the GPT-J model,
* [`run.py`](./run.py) to run the inference on an input text,
## Support Matrix
* FP16
## Usage
### 1. Download weights from HuggingFace (HF) Transformers
```bash
# 1. Weights & config
git clone https://huggingface.co/EleutherAI/gpt-j-6b ./downloads/gptj-6b
pushd ./downloads/gptj-6b && \
rm -f pytorch_model.bin && \
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/pytorch_model.bin && \
popd
# 2. Vocab and merge table
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/vocab.json
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/merges.txt
```
### 2. Build XTRT engine(s)
XTRT-LLM builds XTRT engine(s) using a HF checkpoint. If no checkpoint directory is specified, XTRT-LLM will build engine(s) using
dummy weights.
Examples of build invocations:
```bash
# Build a float16 engine using HF weights.
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16 \
--log_level=verbose \
--enable_context_fmha \
--use_gpt_attention_plugin float16 \
--use_gemm_plugin float16 \
--max_batch_size=32 \
--max_input_len=1919 \
--max_output_len=128 \
--output_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
--model_dir=./downloads/gptj-6b 2>&1 | tee build.log
# Build a float16 engine using dummy weights, useful for performance tests.
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16 \
--log_level=verbose \
--enable_context_fmha \
--use_gpt_attention_plugin float16 \
--use_gemm_plugin float16 \
--max_batch_size=32 \
--max_input_len=1919 \
--max_output_len=128 \
--output_dir=./downloads/gptj-6b/trt_engines/gptj_engine_dummy_weights 2>&1 | tee build.log
```
### 3. Run
To run a XTRT-LLM GPT-J model:
```bash
python3 run.py --max_output_len=50 \
--engine_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
--hf_model_location=./downloads/gptj-6b
```

View File

@@ -0,0 +1,76 @@
# GPT-J
本文档介绍了如何使用昆仑芯XTRT-LLM在单XPU上构建和运行[GPT-J](https://huggingface.co/EleutherAI/gpt-j-6b)模型。
## 概述
XTRT-LLM GPT-J 示例代码位于 [`examples/gptj`](./)。 此文件夹中有以下几个主要文件:
* [`build.py`](./build.py) 构建运行GPT-J模型所需的XTRT引擎
* [`run.py`](./run.py) 基于输入的文字进行推理
## 支持的矩阵
* FP16
## 使用说明
### 1.从HuggingFaceHF Transformers下载权重
```bash
# 1. Weights & config
git clone https://huggingface.co/EleutherAI/gpt-j-6b ./downloads/gptj-6b
pushd ./downloads/gptj-6b && \
rm -f pytorch_model.bin && \
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/pytorch_model.bin && \
popd
# 2. Vocab and merge table
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/vocab.json
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/merges.txt
```
### 2. 构建XTRT引擎
XTRT-LLM从HF checkpoint构建XTRT引擎。如果未指定checkpoint目录XTRT-LLM将使用伪权重构建引擎。
构建调用示例:
```bash
# Build a float16 engine using HF weights.
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16 \
--log_level=verbose \
--enable_context_fmha \
--use_gpt_attention_plugin float16 \
--use_gemm_plugin float16 \
--max_batch_size=32 \
--max_input_len=1919 \
--max_output_len=128 \
--output_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
--model_dir=./downloads/gptj-6b 2>&1 | tee build.log
# Build a float16 engine using dummy weights, useful for performance tests.
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.
python3 build.py --dtype=float16 \
--log_level=verbose \
--enable_context_fmha \
--use_gpt_attention_plugin float16 \
--use_gemm_plugin float16 \
--max_batch_size=32 \
--max_input_len=1919 \
--max_output_len=128 \
--output_dir=./downloads/gptj-6b/trt_engines/gptj_engine_dummy_weights 2>&1 | tee build.log
```
### 3. 运行
要运行XTRT-LLM GPT-J模型请执行以下操作
```bash
python3 run.py --max_output_len=50 \
--engine_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
--hf_model_location=./downloads/gptj-6b
```

489
examples/gptj/build.py Normal file
View File

@@ -0,0 +1,489 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import time
import tvm.tensorrt as trt
import torch
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM
from weight import get_scaling_factors, load_from_awq_gpt_j, load_from_hf_gpt_j
import xtrt_llm
from xtrt_llm.builder import Builder
from xtrt_llm.logger import logger
from xtrt_llm.mapping import Mapping
from xtrt_llm.models import (weight_only_groupwise_quantize,
weight_only_quantize)
from xtrt_llm.network import net_guard
from xtrt_llm.plugin.plugin import ContextFMHAType
from xtrt_llm.quantization import QuantMode
MODEL_NAME = "gptj"
hf_gpt = None
awq_gptj_config = None
def get_engine_name(model, dtype, tp_size, rank):
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
def serialize_engine(engine, path):
logger.info(f'Serializing engine to {path}...')
tik = time.time()
engine.serialize(path)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Engine serialized. Total time: {t}')
def parse_arguments(args):
parser = argparse.ArgumentParser()
parser.add_argument('--world_size',
type=int,
default=1,
help='world size, only support tensor parallelism now')
parser.add_argument(
'--model_dir',
type=str,
default=None,
help='The path to HF GPT-J model / checkpoints to read weights from')
parser.add_argument('--dtype',
type=str,
default='float16',
choices=['float16', 'float32'])
parser.add_argument('--logits_dtype',
type=str,
default='float32',
choices=['float16', 'float32'])
parser.add_argument(
'--timing_cache',
type=str,
default='model.cache',
help=
'The path of to read timing cache from, will be ignored if the file does not exist'
)
parser.add_argument('--log_level', type=str, default='info')
parser.add_argument('--vocab_size', type=int, default=50401)
parser.add_argument('--n_layer', type=int, default=28)
parser.add_argument('--n_positions', type=int, default=2048)
parser.add_argument('--n_embd', type=int, default=4096)
parser.add_argument('--n_head', type=int, default=16)
parser.add_argument('--hidden_act', type=str, default='gelu')
parser.add_argument('--rotary_dim', type=int, default=64)
parser.add_argument('--max_batch_size', type=int, default=256)
parser.add_argument('--max_input_len', type=int, default=200)
parser.add_argument('--max_output_len', type=int, default=200)
parser.add_argument('--max_beam_width', type=int, default=1)
parser.add_argument('--use_gpt_attention_plugin',
nargs='?',
const='float16',
type=str,
default=False,
choices=['float16', 'float32'])
parser.add_argument('--use_gemm_plugin',
nargs='?',
const='float16',
type=str,
default=False,
choices=['float16', 'float32'])
parser.add_argument('--use_weight_only_quant_matmul_plugin',
nargs='?',
const='float16',
type=str,
default=False,
choices=['float16'])
parser.add_argument('--use_layernorm_plugin',
nargs='?',
const='float16',
type=str,
default=False,
choices=['float16', 'float32'])
parser.add_argument('--parallel_build', default=False, action='store_true')
parser.add_argument('--enable_context_fmha',
default=False,
action='store_true')
parser.add_argument('--enable_context_fmha_fp32_acc',
default=False,
action='store_true')
parser.add_argument('--gpus_per_node', type=int, default=8)
parser.add_argument(
'--output_dir',
type=str,
default='gpt_outputs',
help=
'The path to save the serialized engine files, timing cache file and model configs'
)
parser.add_argument('--remove_input_padding',
default=False,
action='store_true')
parser.add_argument('--enable_fp8', default=False, action='store_true')
parser.add_argument(
'--quantized_fp8_model_path',
type=str,
default=None,
help='Path of a quantized model checkpoint that in .npz format')
parser.add_argument(
'--fp8_kv_cache',
default=False,
action="store_true",
help=
'By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV'
)
parser.add_argument(
'--use_inflight_batching',
action="store_true",
default=False,
help="Activates inflight batching mode of gptAttentionPlugin.")
parser.add_argument(
'--enable_two_optimization_profiles',
default=False,
action='store_true',
help=
"Enables two optimization profiles during engine build, for context and generate phases. By default (and for inflight batching too), only 1 opt profile."
)
parser.add_argument(
'--paged_kv_cache',
action="store_true",
default=False,
help=
'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
)
parser.add_argument('--tokens_per_block',
type=int,
default=64,
help='Number of tokens per block in paged KV cache')
parser.add_argument(
'--max_num_tokens',
type=int,
default=None,
help='Define the max number of tokens supported by the engine')
parser.add_argument(
'--per_group',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale weights in the int4 range. '
'per_group chooses at run time, and for each group, a custom scaling factor. '
'The falg is built for GPTQ/AWQ quantization.')
parser.add_argument(
'--use_weight_only',
default=False,
action="store_true",
help='Quantize weights for the various GEMMs to INT4/INT8.'
'See --weight_only_precision to set the precision')
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4'],
help=
'Define the precision for the weights when using weight-only quantization.'
'You must also use --use_weight_only for that argument to have an impact.'
)
parser.add_argument(
'--strongly_typed',
default=False,
action="store_true",
help=
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
)
args = parser.parse_args(args)
logger.set_level(args.log_level)
if not args.remove_input_padding:
if args.use_gpt_attention_plugin:
logger.warning(
f"It is recommended to specify --remove_input_padding when using GPT attention plugin"
)
if args.model_dir is not None:
global hf_gpt
if args.use_weight_only and args.weight_only_precision == 'int4' and args.per_group:
logger.info(f'Loading AWQ GPTJ model from {args.model_dir}...')
global awq_gptj_config
with open(args.model_dir + "/config.json",
encoding='utf-8') as config_file:
awq_gptj_config = json.load(config_file)
args.n_embd = awq_gptj_config['n_embd']
args.n_head = awq_gptj_config['n_head']
args.n_layer = awq_gptj_config['n_layer']
args.n_positions = awq_gptj_config['n_positions']
args.vocab_size = awq_gptj_config['vocab_size']
if args.vocab_size % 64 != 0:
args.vocab_size = int(
(awq_gptj_config['vocab_size'] + 63) / 64) * 64
print(
"vocab_size is {}, to use awq we pad it to {}.".format(
awq_gptj_config['vocab_size'], args.vocab_size))
hf_gpt = torch.load(args.model_dir + "/gptj_quantized.pth")
else:
logger.info(f'Loading HF GPTJ model from {args.model_dir}...')
hf_gpt = AutoModelForCausalLM.from_pretrained(args.model_dir)
args.n_embd = hf_gpt.config.n_embd
args.n_head = hf_gpt.config.n_head
args.n_layer = hf_gpt.config.n_layer
args.n_positions = hf_gpt.config.n_positions
args.vocab_size = hf_gpt.config.vocab_size
assert not (args.use_weight_only and args.weight_only_precision
== 'int8'), "Not support int8 weight only."
assert not (args.use_weight_only and args.weight_only_precision == 'int4'
and args.per_group
== False), "We only support AWQ for int4 weight only."
if args.use_weight_only:
args.quant_mode = QuantMode.use_weight_only(
args.weight_only_precision == 'int4')
else:
args.quant_mode = QuantMode(0)
if args.fp8_kv_cache:
assert (
args.use_gpt_attention_plugin
), "You have to use GPT attention plugin when fp8 KV cache is set"
args.quant_mode = args.quant_mode.set_fp8_kv_cache()
if args.enable_fp8:
args.quant_mode = args.quant_mode.set_fp8_qdq()
if args.use_inflight_batching:
if not args.use_gpt_attention_plugin:
args.use_gpt_attention_plugin = 'float16'
logger.info(
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
)
if not args.remove_input_padding:
args.remove_input_padding = True
logger.info(
"Using remove input padding for inflight batching mode.")
if not args.paged_kv_cache:
args.paged_kv_cache = True
logger.info("Using paged KV cache for inflight batching mode.")
if args.max_num_tokens is not None:
assert args.enable_context_fmha
if args.remove_input_padding or args.use_inflight_batching or args.paged_kv_cache:
assert (
not args.enable_two_optimization_profiles
), "Only 1 opt profile supported for inflight batching and paged kv cache."
return args
def build_rank_engine(builder: Builder,
builder_config: xtrt_llm.builder.BuilderConfig,
engine_name, rank, args):
'''
@brief: Build the engine on the given rank.
@param rank: The rank to build the engine.
@param args: The cmd line arguments.
@return: The built engine.
'''
kv_dtype = trt.float16 if args.dtype == 'float16' else trt.float32
# Initialize Module
xtrt_llm_gpt = xtrt_llm.models.GPTJForCausalLM(
num_layers=args.n_layer,
num_heads=args.n_head,
hidden_size=args.n_embd,
vocab_size=args.vocab_size,
hidden_act=args.hidden_act,
max_position_embeddings=args.n_positions,
rotary_dim=args.rotary_dim,
dtype=kv_dtype,
logits_dtype=args.logits_dtype,
mapping=Mapping(world_size=args.world_size,
rank=rank,
tp_size=args.world_size), # TP only
quant_mode=args.quant_mode)
if args.use_weight_only_quant_matmul_plugin:
xtrt_llm_gpt = weight_only_quantize(xtrt_llm_gpt)
if args.use_weight_only and args.weight_only_precision == 'int4':
if args.per_group:
xtrt_llm_gpt = weight_only_groupwise_quantize(
model=xtrt_llm_gpt,
quant_mode=QuantMode.from_description(
quantize_weights=True,
quantize_activations=False,
per_token=False,
per_channel=False,
per_group=True,
use_int4_weights=True),
group_size=128,
zero=False,
pre_quant_scale=True,
exclude_modules=[],
)
if args.model_dir is not None:
assert hf_gpt is not None, f'Could not load weights from hf_gpt model as it is not loaded yet.'
if args.enable_fp8:
gptj_scaling_factors = get_scaling_factors(
args.quantized_fp8_model_path, args.n_layer, args.quant_mode)
else:
gptj_scaling_factors = None
if args.use_weight_only and args.weight_only_precision == 'int4' and args.per_group:
load_from_awq_gpt_j(xtrt_llm_gpt,
awq_gpt_j=hf_gpt,
config=awq_gptj_config,
dtype=args.dtype)
else:
load_from_hf_gpt_j(xtrt_llm_gpt,
hf_gpt,
args.dtype,
scaling_factors=gptj_scaling_factors)
# Module -> Network
network = builder.create_network()
network.trt_network.name = engine_name
if args.use_gpt_attention_plugin:
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if args.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=args.use_layernorm_plugin)
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
if args.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
if args.enable_context_fmha_fp32_acc:
network.plugin_config.set_context_fmha(
ContextFMHAType.enabled_with_fp32_acc)
if args.use_weight_only_quant_matmul_plugin:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.use_weight_only_quant_matmul_plugin)
if args.use_weight_only:
if args.per_group:
network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
dtype='float16')
if args.world_size > 1:
network.plugin_config.set_nccl_plugin(args.dtype)
if args.remove_input_padding:
network.plugin_config.enable_remove_input_padding()
if args.paged_kv_cache:
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
with net_guard(network):
# Prepare
network.set_named_parameters(xtrt_llm_gpt.named_parameters())
# Forward
inputs = xtrt_llm_gpt.prepare_inputs(
args.max_batch_size,
args.max_input_len,
args.max_output_len,
True,
args.max_beam_width,
max_num_tokens=args.max_num_tokens,
enable_two_optimization_profiles=args.
enable_two_optimization_profiles)
xtrt_llm_gpt(*inputs)
# xtrt_llm.graph_rewriting.optimize(network)
engine = None
# Network -> Engine
engine = builder.build_engine(network, builder_config, compiler="gr")
if rank == 0:
config_path = os.path.join(args.output_dir, 'config.json')
builder.save_config(builder_config, config_path)
return engine
def build(rank, args):
# torch.cuda.set_device(rank % args.gpus_per_node)
xtrt_llm.logger.set_level(args.log_level)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
# when doing serializing build, all ranks share one engine
builder = Builder()
cache = None
for cur_rank in range(args.world_size):
# skip other ranks if parallel_build is enabled
if args.parallel_build and cur_rank != rank:
continue
builder_config = builder.create_builder_config(
name=MODEL_NAME,
precision=args.dtype,
timing_cache=args.timing_cache if cache is None else cache,
tensor_parallel=args.world_size, # TP only
parallel_build=args.parallel_build,
num_layers=args.n_layer,
num_heads=args.n_head,
hidden_size=args.n_embd,
inter_size=args.n_embd * 4,
vocab_size=args.vocab_size,
hidden_act=args.hidden_act,
max_position_embeddings=args.n_positions,
max_batch_size=args.max_batch_size,
max_input_len=args.max_input_len,
max_output_len=args.max_output_len,
max_num_tokens=args.max_num_tokens,
fp8=args.enable_fp8,
quant_mode=args.quant_mode,
strongly_typed=args.strongly_typed)
engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
cur_rank)
engine = build_rank_engine(builder, builder_config, engine_name,
cur_rank, args)
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
# if cur_rank == 0:
# # Use in-memory timing cache for multiple builder passes.
# if not args.parallel_build:
# cache = builder_config.xtrt_builder_config.get_timing_cache()
serialize_engine(engine, os.path.join(args.output_dir, engine_name))
# if rank == 0:
# ok = builder.save_timing_cache(
# builder_config, os.path.join(args.output_dir, "model.cache"))
# assert ok, "Failed to save timing cache."
def run_build(args=None):
args = parse_arguments(args)
tik = time.time()
if args.parallel_build and args.world_size > 1 and \
torch.cuda.device_count() >= args.world_size:
logger.warning(
f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
)
mp.spawn(build, nprocs=args.world_size, args=(args, ))
else:
args.parallel_build = False
logger.info('Serially build TensorRT engines.')
build(0, args)
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Total time of building all {args.world_size} engines: {t}')
if __name__ == '__main__':
run_build()

137
examples/gptj/quantize.py Normal file
View File

@@ -0,0 +1,137 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from examples/quantization/hf_ptq.py
"""
import argparse
import random
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from xtrt_llm._utils import str_dtype_to_torch
from xtrt_llm.logger import logger
from xtrt_llm.models.quantized.ammo import quantize_and_export
def get_calib_dataloader(data="cnn_dailymail",
tokenizer=None,
batch_size=1,
calib_size=512,
block_size=512):
print("Loading calibration dataset")
if data == "pileval":
dataset = load_dataset(
"json",
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
split="train")
dataset = dataset["text"][:calib_size]
elif data == "cnn_dailymail":
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
dataset = dataset["article"][:calib_size]
else:
raise NotImplementedError
# NOTE truncate dataset to n_positions for RoPE in GPT-J
batch_encoded = tokenizer.batch_encode_plus(
dataset,
return_tensors="pt",
padding=True,
truncation=True,
max_length=block_size,
)
batch_encoded = batch_encoded["input_ids"]
batch_encoded = batch_encoded.cuda()
calib_dataloader = DataLoader(batch_encoded,
batch_size=batch_size,
shuffle=False)
return calib_dataloader
def get_tokenizer(ckpt_path, **kwargs):
logger.info(f"Loading tokenizer from {ckpt_path}")
tokenizer = AutoTokenizer.from_pretrained(ckpt_path,
padding_side="left",
**kwargs)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def get_model(ckpt_path, dtype="float16"):
logger.info(f"Loading model from {ckpt_path}")
torch_dtype = str_dtype_to_torch(dtype)
model = AutoModelForCausalLM.from_pretrained(
ckpt_path,
device_map="auto",
trust_remote_code=True,
torch_dtype=torch_dtype,
)
model.eval()
model = model.to(memory_format=torch.channels_last)
return model
def get_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_dir",
type=str,
required=True,
help="Directory of a HF model checkpoint")
parser.add_argument("--dtype", help="Model data type.", default="float16")
parser.add_argument("--qformat",
type=str,
choices=['fp8'],
default='fp8',
help='Quantization format.')
parser.add_argument("--calib_size",
type=int,
default=512,
help="Number of samples for calibration.")
parser.add_argument("--export_path", default="exported_model")
parser.add_argument('--seed', type=int, default=None, help='Random seed')
args = parser.parse_args()
return args
def main():
if not torch.cuda.is_available():
raise EnvironmentError("GPU is required for inference.")
args = get_args()
if args.seed is not None:
random.seed(args.seed)
np.random.seed(args.seed)
tokenizer = get_tokenizer(args.model_dir)
model = get_model(args.model_dir, args.dtype)
calib_dataloader = get_calib_dataloader(tokenizer=tokenizer,
calib_size=args.calib_size)
model = quantize_and_export(model,
qformat=args.qformat,
calib_dataloader=calib_dataloader,
export_path=args.export_path)
if __name__ == "__main__":
main()

284
examples/gptj/run.py Normal file
View File

@@ -0,0 +1,284 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import csv
import json
from pathlib import Path
import numpy as np
import torch
from utils import token_encoder
import xtrt_llm
from xtrt_llm.quantization import QuantMode
from xtrt_llm.runtime import ModelConfig, SamplingConfig
from build import get_engine_name # isort:skip
# GPT3 Related variables
# Reference : https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/gpt_sample.py
MERGES_FILE = "merges.txt"
VOCAB_FILE = "vocab.json"
PAD_ID = 50256
START_ID = 50256
END_ID = 50256
def read_config(config_path: Path):
with open(config_path, 'r') as f:
config = json.load(f)
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
remove_input_padding = config['plugin_config']['remove_input_padding']
world_size = config['builder_config']['tensor_parallel']
assert world_size == xtrt_llm.mpi_world_size(), \
f'Engine world size ({world_size}) != Runtime world size ({xtrt_llm.mpi_world_size()})'
num_heads = config['builder_config']['num_heads'] // world_size
hidden_size = config['builder_config']['hidden_size'] // world_size
vocab_size = config['builder_config']['vocab_size']
num_layers = config['builder_config']['num_layers']
quant_mode = QuantMode(config['builder_config']['quant_mode'])
paged_kv_cache = config['plugin_config']['paged_kv_cache']
tokens_per_block = config['plugin_config']['tokens_per_block']
dtype = config['builder_config']['precision']
model_config = ModelConfig(num_heads=num_heads,
num_kv_heads=num_heads,
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
gpt_attention_plugin=use_gpt_attention_plugin,
remove_input_padding=remove_input_padding,
paged_kv_cache=paged_kv_cache,
tokens_per_block=tokens_per_block,
quant_mode=quant_mode,
dtype=dtype)
max_input_len = config['builder_config']['max_input_len']
return model_config, world_size, dtype, max_input_len
def parse_input(input_text: str, input_file: str, tokenizer, pad_id: int,
remove_input_padding: bool):
input_tokens = []
if input_file is None:
input_tokens.append(tokenizer.encode(input_text))
else:
if input_file.endswith('.csv'):
with open(input_file, 'r') as csv_file:
csv_reader = csv.reader(csv_file, delimiter=',')
for line in csv_reader:
input_tokens.append(np.array(line, dtype='int32'))
elif input_file.endswith('.npy'):
inputs = np.load(input_file)
for row in inputs:
row = row[row != pad_id]
input_tokens.append(row)
else:
print('Input file format not supported.')
raise SystemExit
input_ids = None
input_lengths = torch.tensor([len(x) for x in input_tokens],
dtype=torch.int32).cuda()
if remove_input_padding:
input_ids = np.concatenate(input_tokens)
input_ids = torch.tensor(input_ids, dtype=torch.int32,
device='cuda').unsqueeze(0)
else:
input_ids = torch.nested.to_padded_tensor(
torch.nested.nested_tensor(input_tokens, dtype=torch.int32),
pad_id).cuda()
return input_ids, input_lengths
def print_output(output_ids, cum_log_probs, input_lengths, sequence_lengths,
tokenizer, output_csv, output_npy):
num_beams = output_ids.size(1)
if output_csv is None and output_npy is None:
for b in range(input_lengths.size(0)):
inputs = output_ids[b][0][:input_lengths[b]].tolist()
input_text = tokenizer.decode(inputs)
print(f'Input idx: {b}')
print(f'Input: \"{input_text}\"')
for beam in range(num_beams):
output_begin = input_lengths[b]
output_end = sequence_lengths[b][beam]
outputs = output_ids[b][beam][output_begin:output_end].tolist()
output_text = tokenizer.decode(outputs)
if num_beams > 1:
cum_log_prob = cum_log_probs[b][beam]
print(f'Output idx: {b}, beam {beam} (cum_log_prob: {cum_log_prob})')
print(f'Output: \"{output_text}\"')
else:
print(f'Output idx:{b}')
print(f'Output: \"{output_text}\"')
output_ids = output_ids.reshape((-1, output_ids.size(2)))
if output_csv is not None:
output_file = Path(output_csv)
output_file.parent.mkdir(exist_ok=True, parents=True)
outputs = output_ids.tolist()
with open(output_file, 'w') as csv_file:
writer = csv.writer(csv_file, delimiter=',')
writer.writerows(outputs)
if output_npy is not None:
output_file = Path(output_npy)
output_file.parent.mkdir(exist_ok=True, parents=True)
outputs = np.array(output_ids.cpu().contiguous(), dtype='int32')
np.save(output_file, outputs)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--max_output_len', type=int, required=True)
parser.add_argument('--log_level', type=str, default='error')
parser.add_argument('--engine_dir', type=str, default='gpt_outputs')
parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('--min_length', type=int, default=1)
parser.add_argument('--input_text',
type=str,
default='Born in north-east France, Soyer trained as a')
parser.add_argument(
'--input_tokens',
dest='input_file',
type=str,
help=
'CSV or Numpy file containing tokenized input. Alternative to text input.',
default=None)
parser.add_argument('--output_csv',
type=str,
help='CSV file where the tokenized output is stored.',
default=None)
parser.add_argument('--output_npy',
type=str,
help='Numpy file where the tokenized output is stored.',
default=None)
parser.add_argument(
'--hf_model_location',
type=str,
default="gptj_model",
help=
'The hugging face model location stores the merges.txt and vocab.json to create tokenizer'
)
parser.add_argument(
'--performance_test_scale',
type=str,
help=
"Scale for performance test. e.g., 8x1024x64 (batch_size, input_text_length, max_output_length)",
default="")
return parser.parse_args()
def generate(
max_output_len: int,
log_level: str = 'error',
engine_dir: str = 'gpt_outputs',
input_text: str = 'Born in north-east France, Soyer trained as a',
input_file: str = None,
output_csv: str = None,
output_npy: str = None,
hf_model_location: str = 'gptj',
num_beams: int = 1,
min_length: int = 1,
performance_test_scale: str = "",
):
xtrt_llm.logger.set_level(log_level)
engine_dir = Path(engine_dir)
config_path = engine_dir / 'config.json'
model_config, world_size, dtype, max_input_len = read_config(config_path)
runtime_rank = xtrt_llm.mpi_rank()
runtime_mapping = xtrt_llm.Mapping(world_size,
runtime_rank,
tp_size=world_size)
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
vocab_file = Path(hf_model_location) / VOCAB_FILE
merges_file = Path(hf_model_location) / MERGES_FILE
assert vocab_file.is_file(), f"{vocab_file} does not exist"
assert merges_file.is_file(), f"{merges_file} does not exist"
tokenizer = token_encoder.get_encoder(vocab_file, merges_file)
sampling_config = SamplingConfig(end_id=END_ID,
pad_id=PAD_ID,
num_beams=num_beams,
min_length=min_length)
engine_name = get_engine_name('gptj', dtype, world_size, runtime_rank)
# serialize_path = Path(engine_dir) / engine_name
serialize_path = str(engine_dir) + "/" + engine_name
# with open(serialize_path, 'rb') as f:
# engine_buffer = f.read()
decoder = xtrt_llm.runtime.GenerationSession(model_config,
serialize_path,
runtime_mapping,
debug_mode=False,
debug_tensors_to_save=None)
input_ids, input_lengths = parse_input(input_text, input_file, tokenizer,
PAD_ID,
model_config.remove_input_padding)
if performance_test_scale != "":
performance_test_scale_list = performance_test_scale.split("E")
for scale in performance_test_scale_list:
xtrt_llm.logger.info(f"Running performance test with scale {scale}")
bs, seqlen, _max_output_len = [int(x) for x in scale.split("x")]
_input_ids = torch.from_numpy(
np.zeros((bs, seqlen)).astype("int32")).cuda()
_input_lengths = torch.from_numpy(
np.full((bs, ), seqlen).astype("int32")).cuda()
_max_input_length = torch.max(_input_lengths).item()
decoder.setup(_input_lengths.size(0), _max_input_length,
_max_output_len, num_beams)
_output_gen_ids = decoder.decode(_input_ids,
_input_lengths,
sampling_config,
output_sequence_lengths=True,
return_dict=True)
max_input_length = torch.max(input_lengths).item()
decoder.setup(input_lengths.size(0),
max_input_length,
max_output_len,
beam_width=num_beams)
outputs = decoder.decode(input_ids,
input_lengths,
sampling_config,
output_sequence_lengths=True,
return_dict=True)
output_ids = outputs['output_ids']
sequence_lengths = outputs['sequence_lengths']
torch.cuda.synchronize()
cum_log_probs = decoder.cum_log_probs if num_beams > 1 else None
if runtime_rank == 0:
print_output(output_ids, cum_log_probs, input_lengths, sequence_lengths,
tokenizer, output_csv, output_npy)
if __name__ == '__main__':
args = parse_arguments()
generate(**vars(args))

8
examples/gptj/run.sh Normal file
View File

@@ -0,0 +1,8 @@
XMLIR_D_XPU_L3_SIZE=0 \
python3 run.py \
--engine_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
--hf_model_location=./downloads/gptj-6b \
--max_output_len=2048 \
--performance_test_scale=1x512x512E1x1024x1024E1x2000x64E1x2048x2048E2x512x512E2x1024x1024E2x2000x64E2x2048x2048E4x512x512E\
4x1024x1024E4x2000x64E4x2048x2048E8x512x512E8x1024x1024E8x2000x64 \
--log_level=info

409
examples/gptj/summarize.py Normal file
View File

@@ -0,0 +1,409 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import copy
import json
import os
import random
import numpy as np
import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForCausalLM, AutoTokenizer
import xtrt_llm
import xtrt_llm.profiler as profiler
from xtrt_llm.logger import logger
from xtrt_llm.quantization import QuantMode
from build import get_engine_name # isort:skip
def TRTGPTJ(args, config):
dtype = config['builder_config']['precision']
world_size = config['builder_config']['tensor_parallel']
assert world_size == xtrt_llm.mpi_world_size(), \
f'Engine world size ({world_size}) != Runtime world size ({xtrt_llm.mpi_world_size()})'
world_size = config['builder_config']['tensor_parallel']
num_heads = config['builder_config']['num_heads'] // world_size
hidden_size = config['builder_config']['hidden_size'] // world_size
vocab_size = config['builder_config']['vocab_size']
num_layers = config['builder_config']['num_layers']
use_gpt_attention_plugin = bool(
config['plugin_config']['gpt_attention_plugin'])
remove_input_padding = config['plugin_config']['remove_input_padding']
quant_mode = QuantMode(config['builder_config'].get('quant_mode', 0))
paged_kv_cache = config['plugin_config']['paged_kv_cache']
tokens_per_block = config['plugin_config']['tokens_per_block']
model_config = xtrt_llm.runtime.ModelConfig(
vocab_size=vocab_size,
num_layers=num_layers,
num_heads=num_heads,
num_kv_heads=num_heads,
hidden_size=hidden_size,
gpt_attention_plugin=use_gpt_attention_plugin,
remove_input_padding=remove_input_padding,
paged_kv_cache=paged_kv_cache,
tokens_per_block=tokens_per_block,
quant_mode=quant_mode,
dtype=dtype)
runtime_rank = xtrt_llm.mpi_rank()
runtime_mapping = xtrt_llm.Mapping(world_size,
runtime_rank,
tp_size=world_size)
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
engine_name = get_engine_name('gptj', dtype, world_size, runtime_rank)
serialize_path = os.path.join(args.engine_dir, engine_name)
xtrt_llm.logger.set_level(args.log_level)
with open(serialize_path, 'rb') as f:
engine_buffer = f.read()
decoder = xtrt_llm.runtime.GenerationSession(model_config,
engine_buffer,
runtime_mapping)
return decoder
def main(args):
runtime_rank = xtrt_llm.mpi_rank()
logger.set_level(args.log_level)
test_hf = args.test_hf and runtime_rank == 0 # only run hf on rank 0
test_trt_llm = args.test_trt_llm
model_dir = args.model_dir
tokenizer = AutoTokenizer.from_pretrained(model_dir,
padding_side='left',
model_max_length=2048,
truncation=True)
tokenizer.pad_token = tokenizer.eos_token
dataset_cnn = load_dataset("ccdv/cnn_dailymail",
'3.0.0',
cache_dir=args.dataset_path)
config_path = os.path.join(args.engine_dir, 'config.json')
with open(config_path, 'r') as f:
config = json.load(f)
max_batch_size = args.batch_size
# runtime parameters
# repetition_penalty = 1
top_k = args.top_k
output_len = args.output_len
test_token_num = 923
# top_p = 0.0
# random_seed = 5
temperature = 1
num_beams = args.num_beams
pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0]
if test_trt_llm:
xtrt_llm_gpt = TRTGPTJ(args, config)
if test_hf:
model = AutoModelForCausalLM.from_pretrained(model_dir)
model.cuda()
if args.data_type == 'fp16':
model.half()
def summarize_xtrt_llm(datapoint):
batch_size = len(datapoint['article'])
line = copy.copy(datapoint['article'])
line_encoded = []
input_lengths = []
for i in range(batch_size):
line[i] = line[i] + ' TL;DR: '
line[i] = line[i].strip()
line[i] = line[i].replace(" n't", "n't")
input_id = tokenizer.encode(line[i],
return_tensors='pt').type(torch.int32)
input_id = input_id[:, -test_token_num:]
line_encoded.append(input_id)
input_lengths.append(input_id.shape[-1])
# do padding, should move outside the profiling to prevent the overhead
max_length = max(input_lengths)
if xtrt_llm_gpt.remove_input_padding:
line_encoded = [
torch.tensor(t, dtype=torch.int32).cuda() for t in line_encoded
]
else:
# do padding, should move outside the profiling to prevent the overhead
for i in range(batch_size):
pad_size = max_length - input_lengths[i]
pad = torch.ones([1, pad_size]).type(torch.int32) * pad_id
line_encoded[i] = torch.cat(
[torch.tensor(line_encoded[i], dtype=torch.int32), pad],
axis=-1)
line_encoded = torch.cat(line_encoded, axis=0).cuda()
input_lengths = torch.tensor(input_lengths,
dtype=torch.int32).cuda()
sampling_config = xtrt_llm.runtime.SamplingConfig(
end_id=end_id, pad_id=pad_id, top_k=top_k, num_beams=num_beams)
with torch.no_grad():
xtrt_llm_gpt.setup(batch_size,
max_context_length=max_length,
max_new_tokens=output_len,
beam_width=num_beams)
if xtrt_llm_gpt.remove_input_padding:
output_ids = xtrt_llm_gpt.decode_batch(
line_encoded, sampling_config)
else:
output_ids = xtrt_llm_gpt.decode(
line_encoded,
input_lengths,
sampling_config,
)
torch.cuda.synchronize()
# Extract a list of tensors of shape beam_width x output_ids.
output_beams_list, output_ids_list = [], []
if xtrt_llm_gpt.mapping.is_first_pp_rank():
output_beams_list = [
tokenizer.batch_decode(output_ids[batch_idx, :,
input_lengths[batch_idx]:],
skip_special_tokens=True)
for batch_idx in range(batch_size)
]
output_ids_list = [
output_ids[batch_idx, :, input_lengths[batch_idx]:]
for batch_idx in range(batch_size)
]
return output_beams_list, output_ids_list
def summarize_hf(datapoint):
batch_size = len(datapoint['article'])
if batch_size > 1:
logger.warning(
f"HF does not support batch_size > 1 to verify correctness due to padding. Current batch size is {batch_size}"
)
line = copy.copy(datapoint['article'])
for i in range(batch_size):
line[i] = line[i] + ' TL;DR: '
line[i] = line[i].strip()
line[i] = line[i].replace(" n't", "n't")
line_encoded = tokenizer(line,
return_tensors='pt',
padding=True,
truncation=True)["input_ids"].type(torch.int64)
line_encoded = line_encoded[:, -test_token_num:]
line_encoded = line_encoded.cuda()
with torch.no_grad():
output = model.generate(line_encoded,
max_length=len(line_encoded[0]) +
output_len,
top_k=top_k,
temperature=temperature,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
num_beams=num_beams,
num_return_sequences=num_beams,
early_stopping=True)
tokens_list = output[:, len(line_encoded[0]):].tolist()
output = output.reshape([batch_size, num_beams, -1])
output_lines_list = [
tokenizer.batch_decode(output[:, i, len(line_encoded[0]):],
skip_special_tokens=True)
for i in range(num_beams)
]
return output_lines_list, tokens_list
if test_trt_llm:
datapoint = dataset_cnn['test'][0:1]
summary, _ = summarize_xtrt_llm(datapoint)
if runtime_rank == 0:
logger.info(
"---------------------------------------------------------")
logger.info("XTRT-LLM Generated : ")
logger.info(f" Article : {datapoint['article']}")
logger.info(f"\n Highlights : {datapoint['highlights']}")
logger.info(f"\n Summary : {summary}")
logger.info(
"---------------------------------------------------------")
if test_hf:
datapoint = dataset_cnn['test'][0:1]
summary, _ = summarize_hf(datapoint)
logger.info("---------------------------------------------------------")
logger.info("HF Generated : ")
logger.info(f" Article : {datapoint['article']}")
logger.info(f"\n Highlights : {datapoint['highlights']}")
logger.info(f"\n Summary : {summary}")
logger.info("---------------------------------------------------------")
xtrt_llm_result = [[] for _ in range(num_beams)]
hf_result = [[] for _ in range(num_beams)]
ite_count = 0
data_point_idx = 0
# Support running the set with different order to verify correctness
test_idx = list(
range(min(len(dataset_cnn['test']), max_batch_size * args.max_ite)))
random.seed(args.random_seed)
random.shuffle(test_idx)
while (data_point_idx < len(dataset_cnn['test'])) and (ite_count <
args.max_ite):
if runtime_rank == 0:
logger.debug(
f"run data_point {data_point_idx} ~ {data_point_idx + max_batch_size}"
)
datapoint = dataset_cnn['test'][test_idx[data_point_idx:(
data_point_idx + max_batch_size)]]
if test_trt_llm:
profiler.start('xtrt_llm')
summary_xtrt_llm, tokens_xtrt_llm = summarize_xtrt_llm(
datapoint)
profiler.stop('xtrt_llm')
if test_hf:
profiler.start('hf')
summary_hf, tokens_hf = summarize_hf(datapoint)
profiler.stop('hf')
if runtime_rank == 0:
if test_trt_llm:
for batch_idx in range(len(summary_xtrt_llm)):
for beam_idx in range(num_beams):
xtrt_llm_result[beam_idx].append(
tuple([
datapoint['id'][batch_idx],
summary_xtrt_llm[batch_idx][beam_idx],
datapoint['highlights'][batch_idx]
]))
if test_hf:
for beam_idx in range(num_beams):
for batch_idx in range(len(summary_hf[beam_idx])):
hf_result[beam_idx].append(
tuple([
datapoint['id'][batch_idx],
summary_hf[beam_idx][batch_idx],
datapoint['highlights'][batch_idx]
]))
logger.debug('-' * 100)
logger.debug(f"Article : {datapoint['article']}")
if test_trt_llm:
logger.debug(f'XTRT-LLM Summary: {summary_xtrt_llm}')
if test_hf:
logger.debug(f'HF Summary: {summary_hf}')
logger.debug(f"highlights : {datapoint['highlights']}")
data_point_idx += max_batch_size
ite_count += 1
if runtime_rank == 0:
if test_trt_llm:
np.random.seed(0) # rouge score use sampling to compute the score
logger.info(
f'XTRT-LLM (total latency: {profiler.elapsed_time_in_sec("xtrt_llm")} sec)'
)
for beam_idx in range(num_beams):
# Because 'rouge' uses sampling to compute the scores, the scores
# would be different when the results are same with different order.
# So, sorting them first to prevent this issue.
metric_xtrt_llm = load_metric("rouge")
metric_xtrt_llm.seed = 0
beams_results = sorted(xtrt_llm_result[beam_idx])
for j in range(len(beams_results)):
metric_xtrt_llm.add_batch(
predictions=[beams_results[j][1]],
references=[beams_results[j][2]])
logger.info(f"XTRT-LLM beam {beam_idx} result")
computed_metrics_xtrt_llm = metric_xtrt_llm.compute()
for key in computed_metrics_xtrt_llm.keys():
logger.info(
f' {key} : {computed_metrics_xtrt_llm[key].mid[2]*100}'
)
if args.check_accuracy and beam_idx == 0:
assert computed_metrics_xtrt_llm['rouge1'].mid[
2] * 100 > args.xtrt_llm_rouge1_threshold
if test_hf:
np.random.seed(0) # rouge score use sampling to compute the score
logger.info(
f'Hugging Face (total latency: {profiler.elapsed_time_in_sec("hf")} sec)'
)
for beam_idx in range(num_beams):
metric_tensorrt_hf = load_metric("rouge")
metric_tensorrt_hf.seed = 0
beams_results = sorted(hf_result[beam_idx])
for j in range(len(beams_results)):
metric_tensorrt_hf.add_batch(
predictions=[beams_results[j][1]],
references=[beams_results[j][2]])
logger.info(f"HF beam {beam_idx} result")
computed_metrics_hf = metric_tensorrt_hf.compute()
for key in computed_metrics_hf.keys():
logger.info(
f' {key} : {computed_metrics_hf[key].mid[2]*100}')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, default='EleutherAI/gpt-j-6B')
parser.add_argument('--test_hf', action='store_true')
parser.add_argument('--test_trt_llm', action='store_true')
parser.add_argument('--data_type',
type=str,
choices=['fp32', 'fp16'],
default='fp32')
parser.add_argument('--dataset_path', type=str, default='')
parser.add_argument('--log_level', type=str, default='info')
parser.add_argument('--engine_dir', type=str, default='gptj_engine')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--max_ite', type=int, default=20)
parser.add_argument('--output_len', type=int, default=100)
parser.add_argument('--check_accuracy', action='store_true')
parser.add_argument('--xtrt_llm_rouge1_threshold',
type=float,
default=15.0)
parser.add_argument('--num_beams', type=int, default=1)
parser.add_argument('--top_k', type=int, default=1)
parser.add_argument('--random_seed', type=int, default=0)
args = parser.parse_args()
main(args)

View File

@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,178 @@
"""Byte pair encoding utilities"""
# Modified MIT License
# Software Copyright (c) 2019 OpenAI
# We dont claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
# We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# The above copyright notice and this permission notice need not be included
# with content created by the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from functools import lru_cache
import regex as re
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"),
ord("~") + 1)) + list(range(
ord("¡"),
ord("¬") + 1)) + list(range(ord("®"),
ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors='replace'):
self.encoder = encoder
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i +
1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c]
for c in text]).decode('utf-8', errors=self.errors)
return text
def batch_decode(self, output):
ret = []
for tokens in output:
ret.append(self.decode(tokens))
return ret
def get_encoder(vocab_file, bpe_file):
with open(vocab_file, 'r', encoding="utf-8") as f:
encoder = json.load(f)
with open(bpe_file, 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [
tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]
]
return Encoder(
encoder=encoder,
bpe_merges=bpe_merges,
)

455
examples/gptj/weight.py Normal file
View File

@@ -0,0 +1,455 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from operator import attrgetter
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import xtrt_llm
import xtrt_llm.logger as logger
from xtrt_llm._utils import str_dtype_to_torch
from xtrt_llm.models import GPTJForCausalLM
from xtrt_llm.models.quantized.quant import get_dummy_quant_scales
from xtrt_llm.quantization import QuantMode
def get_scaling_factors(
model_path: Union[str, Path],
num_layers: int,
quant_mode: Optional[QuantMode] = None,
) -> Optional[Dict[str, List[int]]]:
""" Get the scaling factors for GPT-J model
Returns a dictionary of scaling factors for the selected layers of the
GPT-J model.
Args:
model_path (str): Path to the quantized GPT-J model
layers (list): List of layers to get the scaling factors for. If None,
all layers are selected.
Returns:
dict: Dictionary of scaling factors for the selected layers of the
GPT-J model.
example:
{
'qkv_act': qkv_act_scale,
'qkv_weights': qkv_weights_scale,
'qkv_output' : qkv_outputs_scale,
'dense_act': dense_act_scale,
'dense_weights': dense_weights_scale,
'fc_act': fc_act_scale,
'fc_weights': fc_weights_scale,
'proj_act': proj_act_scale,
'proj_weights': proj_weights_scale,
}
"""
if model_path is None:
logger.warning(f"--quantized_fp8_model_path not specified. "
f"Initialize quantization scales automatically.")
return get_dummy_quant_scales(num_layers)
weight_dict = np.load(model_path)
# yapf: disable
scaling_factor = {
'qkv_act': [],
'qkv_weights': [],
'qkv_output': [],
'dense_act': [],
'dense_weights': [],
'fc_act': [],
'fc_weights': [],
'proj_act': [],
'proj_weights': [],
}
for layer in range(num_layers):
scaling_factor['qkv_act'].append(max(
weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
))
scaling_factor['qkv_weights'].append(max(
weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
))
if quant_mode is not None and quant_mode.has_fp8_kv_cache():
# Not calibrarting KV cache.
scaling_factor['qkv_output'].append(1.0)
scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
# yapf: enable
for k, v in scaling_factor.items():
assert len(v) == num_layers, \
f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
return scaling_factor
def load_from_hf_gpt_j(xtrt_llm_gpt_j: GPTJForCausalLM,
hf_gpt_j,
dtype="float32",
scaling_factors=None):
hf_model_gptj_block_names = [
"ln_1.weight",
"ln_1.bias",
"mlp.fc_in.weight",
"mlp.fc_in.bias",
"mlp.fc_out.weight",
"mlp.fc_out.bias",
]
xtrt_llm_model_gptj_block_names = [
"input_layernorm.weight",
"input_layernorm.bias",
"mlp.fc.weight",
"mlp.fc.bias",
"mlp.proj.weight",
"mlp.proj.bias",
]
quant_mode = getattr(xtrt_llm_gpt_j, 'quant_mode', QuantMode(0))
xtrt_llm.logger.info('Loading weights from HF GPT-J...')
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
hf_gpt_j_state_dict = hf_gpt_j.state_dict()
v = hf_gpt_j_state_dict.get('transformer.wte.weight')
xtrt_llm_gpt_j.embedding.weight.value = v.to(torch_dtype).cpu().numpy()
n_layer = hf_gpt_j.config.n_layer
for layer_idx in range(n_layer):
prefix = "transformer.h." + str(layer_idx) + "."
for idx, hf_attr in enumerate(hf_model_gptj_block_names):
v = hf_gpt_j_state_dict.get(prefix + hf_attr)
layer = attrgetter(xtrt_llm_model_gptj_block_names[idx])(
xtrt_llm_gpt_j.layers[layer_idx])
if idx == 2 and scaling_factors:
xtrt_llm_gpt_j.layers[
layer_idx].mlp.fc.activation_scaling_factor.value = np.array(
[scaling_factors['fc_act'][layer_idx]],
dtype=np.float32)
xtrt_llm_gpt_j.layers[
layer_idx].mlp.fc.weights_scaling_factor.value = np.array(
[scaling_factors['fc_weights'][layer_idx]],
dtype=np.float32)
elif idx == 4 and scaling_factors:
xtrt_llm_gpt_j.layers[
layer_idx].mlp.proj.activation_scaling_factor.value = np.array(
[scaling_factors['proj_act'][layer_idx]],
dtype=np.float32)
xtrt_llm_gpt_j.layers[
layer_idx].mlp.proj.weights_scaling_factor.value = np.array(
[scaling_factors['proj_weights'][layer_idx]],
dtype=np.float32)
setattr(layer, 'value', v.to(torch_dtype).cpu().numpy())
# Attention QKV Linear
# concatenate the Q, K, V layers weights.
q_weights = hf_gpt_j_state_dict.get(prefix + "attn.q_proj.weight")
k_weights = hf_gpt_j_state_dict.get(prefix + "attn.k_proj.weight")
v_weights = hf_gpt_j_state_dict.get(prefix + "attn.v_proj.weight")
qkv_weights = torch.cat((q_weights, k_weights, v_weights))
layer = attrgetter("attention.qkv.weight")(
xtrt_llm_gpt_j.layers[layer_idx])
setattr(layer, "value", qkv_weights.to(torch_dtype).cpu().numpy())
if scaling_factors:
xtrt_llm_gpt_j.layers[
layer_idx].attention.qkv.activation_scaling_factor.value = np.array(
[scaling_factors['qkv_act'][layer_idx]], dtype=np.float32)
xtrt_llm_gpt_j.layers[
layer_idx].attention.qkv.weights_scaling_factor.value = np.array(
[scaling_factors['qkv_weights'][layer_idx]],
dtype=np.float32)
if quant_mode.has_fp8_kv_cache():
if scaling_factors:
xtrt_llm_gpt_j.layers[
layer_idx].attention.kv_orig_quant_scale.value = np.array(
[scaling_factors['qkv_output'][layer_idx]],
dtype=np.float32)
xtrt_llm_gpt_j.layers[
layer_idx].attention.kv_quant_orig_scale.value = np.array(
[1.0 / scaling_factors['qkv_output'][layer_idx]],
dtype=np.float32)
# Attention Dense (out_proj) Linear
v = hf_gpt_j_state_dict.get(prefix + "attn.out_proj.weight")
layer = attrgetter("attention.dense.weight")(
xtrt_llm_gpt_j.layers[layer_idx])
setattr(layer, "value", v.to(torch_dtype).cpu().numpy())
if scaling_factors:
xtrt_llm_gpt_j.layers[
layer_idx].attention.dense.activation_scaling_factor.value = np.array(
[scaling_factors['dense_act'][layer_idx]], dtype=np.float32)
xtrt_llm_gpt_j.layers[
layer_idx].attention.dense.weights_scaling_factor.value = np.array(
[scaling_factors['dense_weights'][layer_idx]],
dtype=np.float32)
v = hf_gpt_j_state_dict.get('transformer.ln_f.weight')
xtrt_llm_gpt_j.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
v = hf_gpt_j_state_dict.get('transformer.ln_f.bias')
xtrt_llm_gpt_j.ln_f.bias.value = v.to(torch_dtype).cpu().numpy()
v = hf_gpt_j_state_dict.get('lm_head.weight')
xtrt_llm_gpt_j.lm_head.weight.value = v.to(torch_dtype).cpu().numpy()
v = hf_gpt_j_state_dict.get('lm_head.bias')
xtrt_llm_gpt_j.lm_head.bias.value = v.to(torch_dtype).cpu().numpy()
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')
def AWQ_quantize_pack_preprocess(weight, scale, group_size, packer,
preprocessor):
scale = scale.repeat_interleave(group_size, dim=0)
weight = weight / scale
weight = torch.round(weight).char()
weight = torch.where(weight > 7, 7, weight)
qweight_int8 = torch.where(weight < -8, -8, weight)
int4_weight = packer(qweight_int8.cpu())
int4_weight = preprocessor(int4_weight, torch.quint4x2)
return int4_weight.view(torch.float32).cpu().numpy()
def process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
preprocessor, torch_dtype):
weight = awq_gpt_j[mPrefix + ".weight"].T.contiguous()
[k, n] = weight.shape
amax = awq_gpt_j[mPrefix + ".weight_quantizer._amax"].reshape(
(n, int(k / group_size))).T.contiguous()
pre_quant_scale = awq_gpt_j[mPrefix +
".input_quantizer._pre_quant_scale"].reshape(
(1, k))
scale = amax / 8.0
mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale, group_size,
packer, preprocessor)
mOp.scale.value = scale.to(torch_dtype).cpu().numpy()
mOp.pre_quant_scale.value = pre_quant_scale.to(torch_dtype).cpu().numpy()
def deSmooth(weight, pre_quant_scale):
[k, n] = weight.shape
pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1,
0).contiguous()
weight = weight * pre_quant_scale
return weight
def reSmooth(weight, pre_quant_scale):
[k, n] = weight.shape
pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1,
0).contiguous()
weight = weight / pre_quant_scale
return weight
def get_scale(weight, group_size):
weight = weight.T.contiguous()
[n, k] = weight.shape
weight = weight.reshape(n, int(k / group_size), group_size)
weight = torch.abs(weight.reshape(-1, group_size))
amax, idx = weight.max(1)
amax = amax.reshape(n, int(k / group_size)).T.contiguous()
return amax / 8
def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale,
group_size):
weight = deSmooth(weight, pre_quant_scale)
weight = reSmooth(weight, avg_pre_quant_scale)
scale = get_scale(weight, group_size)
return weight, scale
def process_and_assign_qkv_weight(awq_gpt_j, prefix, mOp, group_size, packer,
preprocessor, torch_dtype):
q_weight = awq_gpt_j[prefix + "attn.q_proj.weight"].T.contiguous()
k_weight = awq_gpt_j[prefix + "attn.k_proj.weight"].T.contiguous()
v_weight = awq_gpt_j[prefix + "attn.v_proj.weight"].T.contiguous()
[k, n] = q_weight.shape
q_pre_quant_scale = awq_gpt_j[
prefix + "attn.q_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
k_pre_quant_scale = awq_gpt_j[
prefix + "attn.k_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
v_pre_quant_scale = awq_gpt_j[
prefix + "attn.v_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
qkv_pre_quant_scale = (q_pre_quant_scale + k_pre_quant_scale +
v_pre_quant_scale) / 3.0
q_weight, q_scale = reSmooth_and_get_scale(q_weight, q_pre_quant_scale,
qkv_pre_quant_scale, group_size)
k_weight, k_scale = reSmooth_and_get_scale(k_weight, k_pre_quant_scale,
qkv_pre_quant_scale, group_size)
v_weight, v_scale = reSmooth_and_get_scale(v_weight, v_pre_quant_scale,
qkv_pre_quant_scale, group_size)
qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1)
qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1)
mOp.pre_quant_scale.value = qkv_pre_quant_scale.to(
torch_dtype).cpu().numpy()
mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale,
group_size, packer,
preprocessor)
mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy()
def load_from_awq_gpt_j(xtrt_llm_gpt_j: GPTJForCausalLM,
awq_gpt_j,
config,
dtype="float16",
group_size=128):
awq_gptj_block_names = [
"ln_1.weight",
"ln_1.bias",
"mlp.fc_in.bias",
"mlp.fc_out.bias",
]
xtrt_llm_model_gptj_block_names = [
"input_layernorm.weight",
"input_layernorm.bias",
"mlp.fc.bias",
"mlp.proj.bias",
]
getattr(xtrt_llm_gpt_j, 'quant_mode', QuantMode(0))
packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
xtrt_llm.logger.info('Loading weights from AWQ GPT-J...')
tik = time.time()
torch_dtype = str_dtype_to_torch(dtype)
#check if we need to pad vocab
v = awq_gpt_j.get('transformer.wte.weight')
[vocab_size, k] = v.shape
pad_vocab = False
pad_vocab_size = vocab_size
if vocab_size % 64 != 0:
pad_vocab = True
pad_vocab_size = int((vocab_size + 63) / 64) * 64
if pad_vocab:
new_v = torch.zeros([pad_vocab_size, k])
new_v[:vocab_size, :] = v
v = new_v
xtrt_llm_gpt_j.embedding.weight.value = v.to(torch_dtype).cpu().numpy()
n_layer = config["n_layer"]
for layer_idx in range(n_layer):
prefix = "transformer.h." + str(layer_idx) + "."
xtrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
for idx, awq_attr in enumerate(awq_gptj_block_names):
v = awq_gpt_j[prefix + awq_attr]
layer = attrgetter(xtrt_llm_model_gptj_block_names[idx])(
xtrt_llm_gpt_j.layers[layer_idx])
setattr(layer, 'value', v.to(torch_dtype).cpu().numpy())
# Attention QKV Linear
# concatenate the Q, K, V layers weights.
process_and_assign_qkv_weight(
awq_gpt_j, prefix,
xtrt_llm_gpt_j.layers[layer_idx].attention.qkv, group_size,
packer, preprocessor, torch_dtype)
# Attention Dense (out_proj) Linear
mPrefix = prefix + "attn.out_proj"
mOp = xtrt_llm_gpt_j.layers[layer_idx].attention.dense
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
preprocessor, torch_dtype)
# MLP Dense (mlp.fc) Linear
mPrefix = prefix + "mlp.fc_in"
mOp = xtrt_llm_gpt_j.layers[layer_idx].mlp.fc
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
preprocessor, torch_dtype)
# MLP Desne (mlp.proj) Linear
mPrefix = prefix + "mlp.fc_out"
mOp = xtrt_llm_gpt_j.layers[layer_idx].mlp.proj
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
preprocessor, torch_dtype)
v = awq_gpt_j['transformer.ln_f.weight']
xtrt_llm_gpt_j.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
v = awq_gpt_j['transformer.ln_f.bias']
xtrt_llm_gpt_j.ln_f.bias.value = v.to(torch_dtype).cpu().numpy()
#lm_head
if pad_vocab:
weight = awq_gpt_j['lm_head.weight']
[vocab_size, k] = weight.shape
new_weight = torch.zeros([pad_vocab_size, k])
new_weight[:vocab_size, :] = weight
new_weight = new_weight.T.contiguous()
amax = awq_gpt_j['lm_head.weight_quantizer._amax'].reshape(
[vocab_size, int(k / group_size)])
new_amax = torch.ones([pad_vocab_size, int(k / group_size)])
new_amax[:vocab_size, :] = amax
new_amax = new_amax.T.contiguous()
new_scale = new_amax / 8
xtrt_llm_gpt_j.lm_head.qweight.value = AWQ_quantize_pack_preprocess(
new_weight, new_scale, group_size, packer, preprocessor)
xtrt_llm_gpt_j.lm_head.scale.value = new_scale.to(
torch_dtype).cpu().numpy()
xtrt_llm_gpt_j.lm_head.pre_quant_scale.value = awq_gpt_j[
'lm_head.input_quantizer._pre_quant_scale'].to(
torch_dtype).cpu().numpy()
bias = awq_gpt_j['lm_head.bias']
new_bias = torch.zeros([pad_vocab_size])
new_bias[:vocab_size] = bias
xtrt_llm_gpt_j.lm_head.bias.value = new_bias.to(
torch_dtype).cpu().numpy()
else:
mPrefix = "lm_head"
mOp = xtrt_llm_gpt_j.lm_head
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
preprocessor, torch_dtype)
v = awq_gpt_j['lm_head.bias']
xtrt_llm_gpt_j.lm_head.bias.value = v.to(torch_dtype).cpu().numpy()
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')