add pkgs
This commit is contained in:
5
examples/gptj/.gitignore
vendored
Normal file
5
examples/gptj/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
__pycache__/
|
||||
gptj_model/
|
||||
*.log
|
||||
*.txt
|
||||
*.json
|
||||
77
examples/gptj/README.md
Normal file
77
examples/gptj/README.md
Normal file
@@ -0,0 +1,77 @@
|
||||
# GPT-J
|
||||
|
||||
This document explains how to build the [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6b) model using XTRT-LLM and run on a single XPU.
|
||||
|
||||
## Overview
|
||||
|
||||
The XTRT-LLM GPT-J example
|
||||
code is located in [`examples/gptj`](./). There are several main files in that folder:
|
||||
|
||||
* [`build.py`](./build.py) to build the [XTRT] engine(s) needed to run the GPT-J model,
|
||||
* [`run.py`](./run.py) to run the inference on an input text,
|
||||
|
||||
## Support Matrix
|
||||
* FP16
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Download weights from HuggingFace (HF) Transformers
|
||||
|
||||
```bash
|
||||
# 1. Weights & config
|
||||
git clone https://huggingface.co/EleutherAI/gpt-j-6b ./downloads/gptj-6b
|
||||
pushd ./downloads/gptj-6b && \
|
||||
rm -f pytorch_model.bin && \
|
||||
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/pytorch_model.bin && \
|
||||
popd
|
||||
|
||||
# 2. Vocab and merge table
|
||||
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/vocab.json
|
||||
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/merges.txt
|
||||
```
|
||||
|
||||
### 2. Build XTRT engine(s)
|
||||
|
||||
XTRT-LLM builds XTRT engine(s) using a HF checkpoint. If no checkpoint directory is specified, XTRT-LLM will build engine(s) using
|
||||
dummy weights.
|
||||
|
||||
Examples of build invocations:
|
||||
|
||||
```bash
|
||||
# Build a float16 engine using HF weights.
|
||||
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.
|
||||
|
||||
python3 build.py --dtype=float16 \
|
||||
--log_level=verbose \
|
||||
--enable_context_fmha \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_gemm_plugin float16 \
|
||||
--max_batch_size=32 \
|
||||
--max_input_len=1919 \
|
||||
--max_output_len=128 \
|
||||
--output_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
|
||||
--model_dir=./downloads/gptj-6b 2>&1 | tee build.log
|
||||
|
||||
# Build a float16 engine using dummy weights, useful for performance tests.
|
||||
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.
|
||||
|
||||
python3 build.py --dtype=float16 \
|
||||
--log_level=verbose \
|
||||
--enable_context_fmha \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_gemm_plugin float16 \
|
||||
--max_batch_size=32 \
|
||||
--max_input_len=1919 \
|
||||
--max_output_len=128 \
|
||||
--output_dir=./downloads/gptj-6b/trt_engines/gptj_engine_dummy_weights 2>&1 | tee build.log
|
||||
```
|
||||
|
||||
### 3. Run
|
||||
|
||||
To run a XTRT-LLM GPT-J model:
|
||||
|
||||
```bash
|
||||
python3 run.py --max_output_len=50 \
|
||||
--engine_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
|
||||
--hf_model_location=./downloads/gptj-6b
|
||||
```
|
||||
76
examples/gptj/README_CN.md
Normal file
76
examples/gptj/README_CN.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# GPT-J
|
||||
|
||||
本文档介绍了如何使用昆仑芯XTRT-LLM在单XPU上构建和运行[GPT-J](https://huggingface.co/EleutherAI/gpt-j-6b)模型。
|
||||
|
||||
## 概述
|
||||
|
||||
XTRT-LLM GPT-J 示例代码位于 [`examples/gptj`](./)。 此文件夹中有以下几个主要文件:
|
||||
|
||||
* [`build.py`](./build.py) 构建运行GPT-J模型所需的XTRT引擎
|
||||
* [`run.py`](./run.py) 基于输入的文字进行推理
|
||||
|
||||
## 支持的矩阵
|
||||
|
||||
* FP16
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 1.从HuggingFace(HF) Transformers下载权重
|
||||
|
||||
```bash
|
||||
# 1. Weights & config
|
||||
git clone https://huggingface.co/EleutherAI/gpt-j-6b ./downloads/gptj-6b
|
||||
pushd ./downloads/gptj-6b && \
|
||||
rm -f pytorch_model.bin && \
|
||||
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/pytorch_model.bin && \
|
||||
popd
|
||||
|
||||
# 2. Vocab and merge table
|
||||
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/vocab.json
|
||||
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/merges.txt
|
||||
```
|
||||
|
||||
### 2. 构建XTRT引擎
|
||||
|
||||
XTRT-LLM从HF checkpoint构建XTRT引擎。如果未指定checkpoint目录,XTRT-LLM将使用伪权重构建引擎。
|
||||
|
||||
构建调用示例:
|
||||
|
||||
```bash
|
||||
# Build a float16 engine using HF weights.
|
||||
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.
|
||||
|
||||
python3 build.py --dtype=float16 \
|
||||
--log_level=verbose \
|
||||
--enable_context_fmha \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_gemm_plugin float16 \
|
||||
--max_batch_size=32 \
|
||||
--max_input_len=1919 \
|
||||
--max_output_len=128 \
|
||||
--output_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
|
||||
--model_dir=./downloads/gptj-6b 2>&1 | tee build.log
|
||||
|
||||
# Build a float16 engine using dummy weights, useful for performance tests.
|
||||
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.
|
||||
|
||||
python3 build.py --dtype=float16 \
|
||||
--log_level=verbose \
|
||||
--enable_context_fmha \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_gemm_plugin float16 \
|
||||
--max_batch_size=32 \
|
||||
--max_input_len=1919 \
|
||||
--max_output_len=128 \
|
||||
--output_dir=./downloads/gptj-6b/trt_engines/gptj_engine_dummy_weights 2>&1 | tee build.log
|
||||
```
|
||||
|
||||
### 3. 运行
|
||||
|
||||
要运行XTRT-LLM GPT-J模型,请执行以下操作:
|
||||
|
||||
```bash
|
||||
python3 run.py --max_output_len=50 \
|
||||
--engine_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
|
||||
--hf_model_location=./downloads/gptj-6b
|
||||
```
|
||||
489
examples/gptj/build.py
Normal file
489
examples/gptj/build.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import tvm.tensorrt as trt
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from transformers import AutoModelForCausalLM
|
||||
from weight import get_scaling_factors, load_from_awq_gpt_j, load_from_hf_gpt_j
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm.builder import Builder
|
||||
from xtrt_llm.logger import logger
|
||||
from xtrt_llm.mapping import Mapping
|
||||
from xtrt_llm.models import (weight_only_groupwise_quantize,
|
||||
weight_only_quantize)
|
||||
from xtrt_llm.network import net_guard
|
||||
from xtrt_llm.plugin.plugin import ContextFMHAType
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
MODEL_NAME = "gptj"
|
||||
hf_gpt = None
|
||||
awq_gptj_config = None
|
||||
|
||||
|
||||
def get_engine_name(model, dtype, tp_size, rank):
|
||||
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
|
||||
|
||||
|
||||
def serialize_engine(engine, path):
|
||||
logger.info(f'Serializing engine to {path}...')
|
||||
tik = time.time()
|
||||
engine.serialize(path)
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Engine serialized. Total time: {t}')
|
||||
|
||||
|
||||
def parse_arguments(args):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--world_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='world size, only support tensor parallelism now')
|
||||
parser.add_argument(
|
||||
'--model_dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='The path to HF GPT-J model / checkpoints to read weights from')
|
||||
parser.add_argument('--dtype',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float16', 'float32'])
|
||||
parser.add_argument('--logits_dtype',
|
||||
type=str,
|
||||
default='float32',
|
||||
choices=['float16', 'float32'])
|
||||
parser.add_argument(
|
||||
'--timing_cache',
|
||||
type=str,
|
||||
default='model.cache',
|
||||
help=
|
||||
'The path of to read timing cache from, will be ignored if the file does not exist'
|
||||
)
|
||||
parser.add_argument('--log_level', type=str, default='info')
|
||||
parser.add_argument('--vocab_size', type=int, default=50401)
|
||||
parser.add_argument('--n_layer', type=int, default=28)
|
||||
parser.add_argument('--n_positions', type=int, default=2048)
|
||||
parser.add_argument('--n_embd', type=int, default=4096)
|
||||
parser.add_argument('--n_head', type=int, default=16)
|
||||
parser.add_argument('--hidden_act', type=str, default='gelu')
|
||||
parser.add_argument('--rotary_dim', type=int, default=64)
|
||||
parser.add_argument('--max_batch_size', type=int, default=256)
|
||||
parser.add_argument('--max_input_len', type=int, default=200)
|
||||
parser.add_argument('--max_output_len', type=int, default=200)
|
||||
parser.add_argument('--max_beam_width', type=int, default=1)
|
||||
parser.add_argument('--use_gpt_attention_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
choices=['float16', 'float32'])
|
||||
parser.add_argument('--use_gemm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
choices=['float16', 'float32'])
|
||||
parser.add_argument('--use_weight_only_quant_matmul_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
choices=['float16'])
|
||||
parser.add_argument('--use_layernorm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
choices=['float16', 'float32'])
|
||||
parser.add_argument('--parallel_build', default=False, action='store_true')
|
||||
parser.add_argument('--enable_context_fmha',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--enable_context_fmha_fp32_acc',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--gpus_per_node', type=int, default=8)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='gpt_outputs',
|
||||
help=
|
||||
'The path to save the serialized engine files, timing cache file and model configs'
|
||||
)
|
||||
parser.add_argument('--remove_input_padding',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--enable_fp8', default=False, action='store_true')
|
||||
parser.add_argument(
|
||||
'--quantized_fp8_model_path',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path of a quantized model checkpoint that in .npz format')
|
||||
parser.add_argument(
|
||||
'--fp8_kv_cache',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_inflight_batching',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activates inflight batching mode of gptAttentionPlugin.")
|
||||
parser.add_argument(
|
||||
'--enable_two_optimization_profiles',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help=
|
||||
"Enables two optimization profiles during engine build, for context and generate phases. By default (and for inflight batching too), only 1 opt profile."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--paged_kv_cache',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
|
||||
)
|
||||
parser.add_argument('--tokens_per_block',
|
||||
type=int,
|
||||
default=64,
|
||||
help='Number of tokens per block in paged KV cache')
|
||||
parser.add_argument(
|
||||
'--max_num_tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Define the max number of tokens supported by the engine')
|
||||
parser.add_argument(
|
||||
'--per_group',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale weights in the int4 range. '
|
||||
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
||||
'The falg is built for GPTQ/AWQ quantization.')
|
||||
parser.add_argument(
|
||||
'--use_weight_only',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
||||
'See --weight_only_precision to set the precision')
|
||||
parser.add_argument(
|
||||
'--weight_only_precision',
|
||||
const='int8',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default='int8',
|
||||
choices=['int8', 'int4'],
|
||||
help=
|
||||
'Define the precision for the weights when using weight-only quantization.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--strongly_typed',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
|
||||
)
|
||||
args = parser.parse_args(args)
|
||||
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
if not args.remove_input_padding:
|
||||
if args.use_gpt_attention_plugin:
|
||||
logger.warning(
|
||||
f"It is recommended to specify --remove_input_padding when using GPT attention plugin"
|
||||
)
|
||||
|
||||
if args.model_dir is not None:
|
||||
global hf_gpt
|
||||
if args.use_weight_only and args.weight_only_precision == 'int4' and args.per_group:
|
||||
logger.info(f'Loading AWQ GPTJ model from {args.model_dir}...')
|
||||
global awq_gptj_config
|
||||
with open(args.model_dir + "/config.json",
|
||||
encoding='utf-8') as config_file:
|
||||
awq_gptj_config = json.load(config_file)
|
||||
args.n_embd = awq_gptj_config['n_embd']
|
||||
args.n_head = awq_gptj_config['n_head']
|
||||
args.n_layer = awq_gptj_config['n_layer']
|
||||
args.n_positions = awq_gptj_config['n_positions']
|
||||
args.vocab_size = awq_gptj_config['vocab_size']
|
||||
if args.vocab_size % 64 != 0:
|
||||
args.vocab_size = int(
|
||||
(awq_gptj_config['vocab_size'] + 63) / 64) * 64
|
||||
print(
|
||||
"vocab_size is {}, to use awq we pad it to {}.".format(
|
||||
awq_gptj_config['vocab_size'], args.vocab_size))
|
||||
hf_gpt = torch.load(args.model_dir + "/gptj_quantized.pth")
|
||||
else:
|
||||
logger.info(f'Loading HF GPTJ model from {args.model_dir}...')
|
||||
hf_gpt = AutoModelForCausalLM.from_pretrained(args.model_dir)
|
||||
args.n_embd = hf_gpt.config.n_embd
|
||||
args.n_head = hf_gpt.config.n_head
|
||||
args.n_layer = hf_gpt.config.n_layer
|
||||
args.n_positions = hf_gpt.config.n_positions
|
||||
args.vocab_size = hf_gpt.config.vocab_size
|
||||
|
||||
assert not (args.use_weight_only and args.weight_only_precision
|
||||
== 'int8'), "Not support int8 weight only."
|
||||
|
||||
assert not (args.use_weight_only and args.weight_only_precision == 'int4'
|
||||
and args.per_group
|
||||
== False), "We only support AWQ for int4 weight only."
|
||||
|
||||
if args.use_weight_only:
|
||||
args.quant_mode = QuantMode.use_weight_only(
|
||||
args.weight_only_precision == 'int4')
|
||||
else:
|
||||
args.quant_mode = QuantMode(0)
|
||||
|
||||
if args.fp8_kv_cache:
|
||||
assert (
|
||||
args.use_gpt_attention_plugin
|
||||
), "You have to use GPT attention plugin when fp8 KV cache is set"
|
||||
args.quant_mode = args.quant_mode.set_fp8_kv_cache()
|
||||
|
||||
if args.enable_fp8:
|
||||
args.quant_mode = args.quant_mode.set_fp8_qdq()
|
||||
|
||||
if args.use_inflight_batching:
|
||||
if not args.use_gpt_attention_plugin:
|
||||
args.use_gpt_attention_plugin = 'float16'
|
||||
logger.info(
|
||||
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
|
||||
)
|
||||
if not args.remove_input_padding:
|
||||
args.remove_input_padding = True
|
||||
logger.info(
|
||||
"Using remove input padding for inflight batching mode.")
|
||||
if not args.paged_kv_cache:
|
||||
args.paged_kv_cache = True
|
||||
logger.info("Using paged KV cache for inflight batching mode.")
|
||||
|
||||
if args.max_num_tokens is not None:
|
||||
assert args.enable_context_fmha
|
||||
|
||||
if args.remove_input_padding or args.use_inflight_batching or args.paged_kv_cache:
|
||||
assert (
|
||||
not args.enable_two_optimization_profiles
|
||||
), "Only 1 opt profile supported for inflight batching and paged kv cache."
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def build_rank_engine(builder: Builder,
|
||||
builder_config: xtrt_llm.builder.BuilderConfig,
|
||||
engine_name, rank, args):
|
||||
'''
|
||||
@brief: Build the engine on the given rank.
|
||||
@param rank: The rank to build the engine.
|
||||
@param args: The cmd line arguments.
|
||||
@return: The built engine.
|
||||
'''
|
||||
kv_dtype = trt.float16 if args.dtype == 'float16' else trt.float32
|
||||
|
||||
# Initialize Module
|
||||
xtrt_llm_gpt = xtrt_llm.models.GPTJForCausalLM(
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
hidden_size=args.n_embd,
|
||||
vocab_size=args.vocab_size,
|
||||
hidden_act=args.hidden_act,
|
||||
max_position_embeddings=args.n_positions,
|
||||
rotary_dim=args.rotary_dim,
|
||||
dtype=kv_dtype,
|
||||
logits_dtype=args.logits_dtype,
|
||||
mapping=Mapping(world_size=args.world_size,
|
||||
rank=rank,
|
||||
tp_size=args.world_size), # TP only
|
||||
quant_mode=args.quant_mode)
|
||||
if args.use_weight_only_quant_matmul_plugin:
|
||||
xtrt_llm_gpt = weight_only_quantize(xtrt_llm_gpt)
|
||||
if args.use_weight_only and args.weight_only_precision == 'int4':
|
||||
if args.per_group:
|
||||
xtrt_llm_gpt = weight_only_groupwise_quantize(
|
||||
model=xtrt_llm_gpt,
|
||||
quant_mode=QuantMode.from_description(
|
||||
quantize_weights=True,
|
||||
quantize_activations=False,
|
||||
per_token=False,
|
||||
per_channel=False,
|
||||
per_group=True,
|
||||
use_int4_weights=True),
|
||||
group_size=128,
|
||||
zero=False,
|
||||
pre_quant_scale=True,
|
||||
exclude_modules=[],
|
||||
)
|
||||
if args.model_dir is not None:
|
||||
assert hf_gpt is not None, f'Could not load weights from hf_gpt model as it is not loaded yet.'
|
||||
if args.enable_fp8:
|
||||
gptj_scaling_factors = get_scaling_factors(
|
||||
args.quantized_fp8_model_path, args.n_layer, args.quant_mode)
|
||||
else:
|
||||
gptj_scaling_factors = None
|
||||
if args.use_weight_only and args.weight_only_precision == 'int4' and args.per_group:
|
||||
load_from_awq_gpt_j(xtrt_llm_gpt,
|
||||
awq_gpt_j=hf_gpt,
|
||||
config=awq_gptj_config,
|
||||
dtype=args.dtype)
|
||||
else:
|
||||
load_from_hf_gpt_j(xtrt_llm_gpt,
|
||||
hf_gpt,
|
||||
args.dtype,
|
||||
scaling_factors=gptj_scaling_factors)
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
if args.use_gpt_attention_plugin:
|
||||
network.plugin_config.set_gpt_attention_plugin(
|
||||
dtype=args.use_gpt_attention_plugin)
|
||||
if args.use_gemm_plugin:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
|
||||
if args.use_layernorm_plugin:
|
||||
network.plugin_config.set_layernorm_plugin(
|
||||
dtype=args.use_layernorm_plugin)
|
||||
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
|
||||
if args.enable_context_fmha:
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
if args.enable_context_fmha_fp32_acc:
|
||||
network.plugin_config.set_context_fmha(
|
||||
ContextFMHAType.enabled_with_fp32_acc)
|
||||
if args.use_weight_only_quant_matmul_plugin:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype=args.use_weight_only_quant_matmul_plugin)
|
||||
if args.use_weight_only:
|
||||
if args.per_group:
|
||||
network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
if args.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(args.dtype)
|
||||
if args.remove_input_padding:
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
if args.paged_kv_cache:
|
||||
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
|
||||
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(xtrt_llm_gpt.named_parameters())
|
||||
|
||||
# Forward
|
||||
inputs = xtrt_llm_gpt.prepare_inputs(
|
||||
args.max_batch_size,
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
True,
|
||||
args.max_beam_width,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
enable_two_optimization_profiles=args.
|
||||
enable_two_optimization_profiles)
|
||||
xtrt_llm_gpt(*inputs)
|
||||
|
||||
# xtrt_llm.graph_rewriting.optimize(network)
|
||||
|
||||
engine = None
|
||||
|
||||
# Network -> Engine
|
||||
engine = builder.build_engine(network, builder_config, compiler="gr")
|
||||
if rank == 0:
|
||||
config_path = os.path.join(args.output_dir, 'config.json')
|
||||
builder.save_config(builder_config, config_path)
|
||||
return engine
|
||||
|
||||
|
||||
def build(rank, args):
|
||||
# torch.cuda.set_device(rank % args.gpus_per_node)
|
||||
xtrt_llm.logger.set_level(args.log_level)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
# when doing serializing build, all ranks share one engine
|
||||
builder = Builder()
|
||||
|
||||
cache = None
|
||||
for cur_rank in range(args.world_size):
|
||||
# skip other ranks if parallel_build is enabled
|
||||
if args.parallel_build and cur_rank != rank:
|
||||
continue
|
||||
|
||||
builder_config = builder.create_builder_config(
|
||||
name=MODEL_NAME,
|
||||
precision=args.dtype,
|
||||
timing_cache=args.timing_cache if cache is None else cache,
|
||||
tensor_parallel=args.world_size, # TP only
|
||||
parallel_build=args.parallel_build,
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
hidden_size=args.n_embd,
|
||||
inter_size=args.n_embd * 4,
|
||||
vocab_size=args.vocab_size,
|
||||
hidden_act=args.hidden_act,
|
||||
max_position_embeddings=args.n_positions,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
fp8=args.enable_fp8,
|
||||
quant_mode=args.quant_mode,
|
||||
strongly_typed=args.strongly_typed)
|
||||
|
||||
engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
|
||||
cur_rank)
|
||||
engine = build_rank_engine(builder, builder_config, engine_name,
|
||||
cur_rank, args)
|
||||
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
|
||||
|
||||
# if cur_rank == 0:
|
||||
# # Use in-memory timing cache for multiple builder passes.
|
||||
# if not args.parallel_build:
|
||||
# cache = builder_config.xtrt_builder_config.get_timing_cache()
|
||||
|
||||
serialize_engine(engine, os.path.join(args.output_dir, engine_name))
|
||||
|
||||
# if rank == 0:
|
||||
# ok = builder.save_timing_cache(
|
||||
# builder_config, os.path.join(args.output_dir, "model.cache"))
|
||||
# assert ok, "Failed to save timing cache."
|
||||
|
||||
|
||||
def run_build(args=None):
|
||||
args = parse_arguments(args)
|
||||
tik = time.time()
|
||||
if args.parallel_build and args.world_size > 1 and \
|
||||
torch.cuda.device_count() >= args.world_size:
|
||||
logger.warning(
|
||||
f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
|
||||
)
|
||||
mp.spawn(build, nprocs=args.world_size, args=(args, ))
|
||||
else:
|
||||
args.parallel_build = False
|
||||
logger.info('Serially build TensorRT engines.')
|
||||
build(0, args)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Total time of building all {args.world_size} engines: {t}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_build()
|
||||
137
examples/gptj/quantize.py
Normal file
137
examples/gptj/quantize.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from examples/quantization/hf_ptq.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from xtrt_llm._utils import str_dtype_to_torch
|
||||
from xtrt_llm.logger import logger
|
||||
from xtrt_llm.models.quantized.ammo import quantize_and_export
|
||||
|
||||
|
||||
def get_calib_dataloader(data="cnn_dailymail",
|
||||
tokenizer=None,
|
||||
batch_size=1,
|
||||
calib_size=512,
|
||||
block_size=512):
|
||||
print("Loading calibration dataset")
|
||||
if data == "pileval":
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||
split="train")
|
||||
dataset = dataset["text"][:calib_size]
|
||||
elif data == "cnn_dailymail":
|
||||
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
|
||||
dataset = dataset["article"][:calib_size]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# NOTE truncate dataset to n_positions for RoPE in GPT-J
|
||||
batch_encoded = tokenizer.batch_encode_plus(
|
||||
dataset,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=block_size,
|
||||
)
|
||||
batch_encoded = batch_encoded["input_ids"]
|
||||
batch_encoded = batch_encoded.cuda()
|
||||
|
||||
calib_dataloader = DataLoader(batch_encoded,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
return calib_dataloader
|
||||
|
||||
|
||||
def get_tokenizer(ckpt_path, **kwargs):
|
||||
logger.info(f"Loading tokenizer from {ckpt_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt_path,
|
||||
padding_side="left",
|
||||
**kwargs)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(ckpt_path, dtype="float16"):
|
||||
logger.info(f"Loading model from {ckpt_path}")
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
ckpt_path,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
model.eval()
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
return model
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--model_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory of a HF model checkpoint")
|
||||
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||
parser.add_argument("--qformat",
|
||||
type=str,
|
||||
choices=['fp8'],
|
||||
default='fp8',
|
||||
help='Quantization format.')
|
||||
parser.add_argument("--calib_size",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Number of samples for calibration.")
|
||||
parser.add_argument("--export_path", default="exported_model")
|
||||
parser.add_argument('--seed', type=int, default=None, help='Random seed')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
if not torch.cuda.is_available():
|
||||
raise EnvironmentError("GPU is required for inference.")
|
||||
|
||||
args = get_args()
|
||||
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
tokenizer = get_tokenizer(args.model_dir)
|
||||
model = get_model(args.model_dir, args.dtype)
|
||||
|
||||
calib_dataloader = get_calib_dataloader(tokenizer=tokenizer,
|
||||
calib_size=args.calib_size)
|
||||
model = quantize_and_export(model,
|
||||
qformat=args.qformat,
|
||||
calib_dataloader=calib_dataloader,
|
||||
export_path=args.export_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
284
examples/gptj/run.py
Normal file
284
examples/gptj/run.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from utils import token_encoder
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
from xtrt_llm.runtime import ModelConfig, SamplingConfig
|
||||
|
||||
from build import get_engine_name # isort:skip
|
||||
|
||||
# GPT3 Related variables
|
||||
# Reference : https://github.com/NVIDIA/FasterTransformer/blob/main/sample/pytorch/gpt_sample.py
|
||||
MERGES_FILE = "merges.txt"
|
||||
VOCAB_FILE = "vocab.json"
|
||||
|
||||
PAD_ID = 50256
|
||||
START_ID = 50256
|
||||
END_ID = 50256
|
||||
|
||||
|
||||
def read_config(config_path: Path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
|
||||
remove_input_padding = config['plugin_config']['remove_input_padding']
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
assert world_size == xtrt_llm.mpi_world_size(), \
|
||||
f'Engine world size ({world_size}) != Runtime world size ({xtrt_llm.mpi_world_size()})'
|
||||
num_heads = config['builder_config']['num_heads'] // world_size
|
||||
hidden_size = config['builder_config']['hidden_size'] // world_size
|
||||
vocab_size = config['builder_config']['vocab_size']
|
||||
num_layers = config['builder_config']['num_layers']
|
||||
quant_mode = QuantMode(config['builder_config']['quant_mode'])
|
||||
paged_kv_cache = config['plugin_config']['paged_kv_cache']
|
||||
tokens_per_block = config['plugin_config']['tokens_per_block']
|
||||
dtype = config['builder_config']['precision']
|
||||
|
||||
model_config = ModelConfig(num_heads=num_heads,
|
||||
num_kv_heads=num_heads,
|
||||
hidden_size=hidden_size,
|
||||
vocab_size=vocab_size,
|
||||
num_layers=num_layers,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
remove_input_padding=remove_input_padding,
|
||||
paged_kv_cache=paged_kv_cache,
|
||||
tokens_per_block=tokens_per_block,
|
||||
quant_mode=quant_mode,
|
||||
dtype=dtype)
|
||||
|
||||
max_input_len = config['builder_config']['max_input_len']
|
||||
|
||||
return model_config, world_size, dtype, max_input_len
|
||||
|
||||
|
||||
def parse_input(input_text: str, input_file: str, tokenizer, pad_id: int,
|
||||
remove_input_padding: bool):
|
||||
input_tokens = []
|
||||
if input_file is None:
|
||||
input_tokens.append(tokenizer.encode(input_text))
|
||||
else:
|
||||
if input_file.endswith('.csv'):
|
||||
with open(input_file, 'r') as csv_file:
|
||||
csv_reader = csv.reader(csv_file, delimiter=',')
|
||||
for line in csv_reader:
|
||||
input_tokens.append(np.array(line, dtype='int32'))
|
||||
elif input_file.endswith('.npy'):
|
||||
inputs = np.load(input_file)
|
||||
for row in inputs:
|
||||
row = row[row != pad_id]
|
||||
input_tokens.append(row)
|
||||
else:
|
||||
print('Input file format not supported.')
|
||||
raise SystemExit
|
||||
|
||||
input_ids = None
|
||||
input_lengths = torch.tensor([len(x) for x in input_tokens],
|
||||
dtype=torch.int32).cuda()
|
||||
if remove_input_padding:
|
||||
input_ids = np.concatenate(input_tokens)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int32,
|
||||
device='cuda').unsqueeze(0)
|
||||
else:
|
||||
input_ids = torch.nested.to_padded_tensor(
|
||||
torch.nested.nested_tensor(input_tokens, dtype=torch.int32),
|
||||
pad_id).cuda()
|
||||
|
||||
return input_ids, input_lengths
|
||||
|
||||
|
||||
def print_output(output_ids, cum_log_probs, input_lengths, sequence_lengths,
|
||||
tokenizer, output_csv, output_npy):
|
||||
|
||||
num_beams = output_ids.size(1)
|
||||
if output_csv is None and output_npy is None:
|
||||
for b in range(input_lengths.size(0)):
|
||||
inputs = output_ids[b][0][:input_lengths[b]].tolist()
|
||||
input_text = tokenizer.decode(inputs)
|
||||
print(f'Input idx: {b}')
|
||||
print(f'Input: \"{input_text}\"')
|
||||
for beam in range(num_beams):
|
||||
output_begin = input_lengths[b]
|
||||
output_end = sequence_lengths[b][beam]
|
||||
outputs = output_ids[b][beam][output_begin:output_end].tolist()
|
||||
output_text = tokenizer.decode(outputs)
|
||||
if num_beams > 1:
|
||||
cum_log_prob = cum_log_probs[b][beam]
|
||||
print(f'Output idx: {b}, beam {beam} (cum_log_prob: {cum_log_prob})')
|
||||
print(f'Output: \"{output_text}\"')
|
||||
else:
|
||||
print(f'Output idx:{b}')
|
||||
print(f'Output: \"{output_text}\"')
|
||||
|
||||
output_ids = output_ids.reshape((-1, output_ids.size(2)))
|
||||
|
||||
if output_csv is not None:
|
||||
output_file = Path(output_csv)
|
||||
output_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
outputs = output_ids.tolist()
|
||||
with open(output_file, 'w') as csv_file:
|
||||
writer = csv.writer(csv_file, delimiter=',')
|
||||
writer.writerows(outputs)
|
||||
|
||||
if output_npy is not None:
|
||||
output_file = Path(output_npy)
|
||||
output_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
outputs = np.array(output_ids.cpu().contiguous(), dtype='int32')
|
||||
np.save(output_file, outputs)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--max_output_len', type=int, required=True)
|
||||
parser.add_argument('--log_level', type=str, default='error')
|
||||
parser.add_argument('--engine_dir', type=str, default='gpt_outputs')
|
||||
parser.add_argument('--num_beams', type=int, default=1)
|
||||
parser.add_argument('--min_length', type=int, default=1)
|
||||
parser.add_argument('--input_text',
|
||||
type=str,
|
||||
default='Born in north-east France, Soyer trained as a')
|
||||
parser.add_argument(
|
||||
'--input_tokens',
|
||||
dest='input_file',
|
||||
type=str,
|
||||
help=
|
||||
'CSV or Numpy file containing tokenized input. Alternative to text input.',
|
||||
default=None)
|
||||
parser.add_argument('--output_csv',
|
||||
type=str,
|
||||
help='CSV file where the tokenized output is stored.',
|
||||
default=None)
|
||||
parser.add_argument('--output_npy',
|
||||
type=str,
|
||||
help='Numpy file where the tokenized output is stored.',
|
||||
default=None)
|
||||
parser.add_argument(
|
||||
'--hf_model_location',
|
||||
type=str,
|
||||
default="gptj_model",
|
||||
help=
|
||||
'The hugging face model location stores the merges.txt and vocab.json to create tokenizer'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--performance_test_scale',
|
||||
type=str,
|
||||
help=
|
||||
"Scale for performance test. e.g., 8x1024x64 (batch_size, input_text_length, max_output_length)",
|
||||
default="")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def generate(
|
||||
max_output_len: int,
|
||||
log_level: str = 'error',
|
||||
engine_dir: str = 'gpt_outputs',
|
||||
input_text: str = 'Born in north-east France, Soyer trained as a',
|
||||
input_file: str = None,
|
||||
output_csv: str = None,
|
||||
output_npy: str = None,
|
||||
hf_model_location: str = 'gptj',
|
||||
num_beams: int = 1,
|
||||
min_length: int = 1,
|
||||
performance_test_scale: str = "",
|
||||
):
|
||||
xtrt_llm.logger.set_level(log_level)
|
||||
|
||||
engine_dir = Path(engine_dir)
|
||||
config_path = engine_dir / 'config.json'
|
||||
model_config, world_size, dtype, max_input_len = read_config(config_path)
|
||||
|
||||
runtime_rank = xtrt_llm.mpi_rank()
|
||||
runtime_mapping = xtrt_llm.Mapping(world_size,
|
||||
runtime_rank,
|
||||
tp_size=world_size)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
|
||||
vocab_file = Path(hf_model_location) / VOCAB_FILE
|
||||
merges_file = Path(hf_model_location) / MERGES_FILE
|
||||
assert vocab_file.is_file(), f"{vocab_file} does not exist"
|
||||
assert merges_file.is_file(), f"{merges_file} does not exist"
|
||||
tokenizer = token_encoder.get_encoder(vocab_file, merges_file)
|
||||
|
||||
sampling_config = SamplingConfig(end_id=END_ID,
|
||||
pad_id=PAD_ID,
|
||||
num_beams=num_beams,
|
||||
min_length=min_length)
|
||||
|
||||
engine_name = get_engine_name('gptj', dtype, world_size, runtime_rank)
|
||||
# serialize_path = Path(engine_dir) / engine_name
|
||||
serialize_path = str(engine_dir) + "/" + engine_name
|
||||
# with open(serialize_path, 'rb') as f:
|
||||
# engine_buffer = f.read()
|
||||
decoder = xtrt_llm.runtime.GenerationSession(model_config,
|
||||
serialize_path,
|
||||
runtime_mapping,
|
||||
debug_mode=False,
|
||||
debug_tensors_to_save=None)
|
||||
|
||||
input_ids, input_lengths = parse_input(input_text, input_file, tokenizer,
|
||||
PAD_ID,
|
||||
model_config.remove_input_padding)
|
||||
|
||||
if performance_test_scale != "":
|
||||
performance_test_scale_list = performance_test_scale.split("E")
|
||||
for scale in performance_test_scale_list:
|
||||
xtrt_llm.logger.info(f"Running performance test with scale {scale}")
|
||||
bs, seqlen, _max_output_len = [int(x) for x in scale.split("x")]
|
||||
_input_ids = torch.from_numpy(
|
||||
np.zeros((bs, seqlen)).astype("int32")).cuda()
|
||||
_input_lengths = torch.from_numpy(
|
||||
np.full((bs, ), seqlen).astype("int32")).cuda()
|
||||
_max_input_length = torch.max(_input_lengths).item()
|
||||
|
||||
decoder.setup(_input_lengths.size(0), _max_input_length,
|
||||
_max_output_len, num_beams)
|
||||
_output_gen_ids = decoder.decode(_input_ids,
|
||||
_input_lengths,
|
||||
sampling_config,
|
||||
output_sequence_lengths=True,
|
||||
return_dict=True)
|
||||
|
||||
max_input_length = torch.max(input_lengths).item()
|
||||
decoder.setup(input_lengths.size(0),
|
||||
max_input_length,
|
||||
max_output_len,
|
||||
beam_width=num_beams)
|
||||
|
||||
outputs = decoder.decode(input_ids,
|
||||
input_lengths,
|
||||
sampling_config,
|
||||
output_sequence_lengths=True,
|
||||
return_dict=True)
|
||||
output_ids = outputs['output_ids']
|
||||
sequence_lengths = outputs['sequence_lengths']
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cum_log_probs = decoder.cum_log_probs if num_beams > 1 else None
|
||||
|
||||
if runtime_rank == 0:
|
||||
print_output(output_ids, cum_log_probs, input_lengths, sequence_lengths,
|
||||
tokenizer, output_csv, output_npy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
generate(**vars(args))
|
||||
8
examples/gptj/run.sh
Normal file
8
examples/gptj/run.sh
Normal file
@@ -0,0 +1,8 @@
|
||||
XMLIR_D_XPU_L3_SIZE=0 \
|
||||
python3 run.py \
|
||||
--engine_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
|
||||
--hf_model_location=./downloads/gptj-6b \
|
||||
--max_output_len=2048 \
|
||||
--performance_test_scale=1x512x512E1x1024x1024E1x2000x64E1x2048x2048E2x512x512E2x1024x1024E2x2000x64E2x2048x2048E4x512x512E\
|
||||
4x1024x1024E4x2000x64E4x2048x2048E8x512x512E8x1024x1024E8x2000x64 \
|
||||
--log_level=info
|
||||
409
examples/gptj/summarize.py
Normal file
409
examples/gptj/summarize.py
Normal file
@@ -0,0 +1,409 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import xtrt_llm
|
||||
import xtrt_llm.profiler as profiler
|
||||
from xtrt_llm.logger import logger
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
from build import get_engine_name # isort:skip
|
||||
|
||||
|
||||
def TRTGPTJ(args, config):
|
||||
dtype = config['builder_config']['precision']
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
assert world_size == xtrt_llm.mpi_world_size(), \
|
||||
f'Engine world size ({world_size}) != Runtime world size ({xtrt_llm.mpi_world_size()})'
|
||||
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
num_heads = config['builder_config']['num_heads'] // world_size
|
||||
hidden_size = config['builder_config']['hidden_size'] // world_size
|
||||
vocab_size = config['builder_config']['vocab_size']
|
||||
num_layers = config['builder_config']['num_layers']
|
||||
use_gpt_attention_plugin = bool(
|
||||
config['plugin_config']['gpt_attention_plugin'])
|
||||
remove_input_padding = config['plugin_config']['remove_input_padding']
|
||||
quant_mode = QuantMode(config['builder_config'].get('quant_mode', 0))
|
||||
paged_kv_cache = config['plugin_config']['paged_kv_cache']
|
||||
tokens_per_block = config['plugin_config']['tokens_per_block']
|
||||
|
||||
model_config = xtrt_llm.runtime.ModelConfig(
|
||||
vocab_size=vocab_size,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_heads,
|
||||
hidden_size=hidden_size,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
remove_input_padding=remove_input_padding,
|
||||
paged_kv_cache=paged_kv_cache,
|
||||
tokens_per_block=tokens_per_block,
|
||||
quant_mode=quant_mode,
|
||||
dtype=dtype)
|
||||
|
||||
runtime_rank = xtrt_llm.mpi_rank()
|
||||
runtime_mapping = xtrt_llm.Mapping(world_size,
|
||||
runtime_rank,
|
||||
tp_size=world_size)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
|
||||
engine_name = get_engine_name('gptj', dtype, world_size, runtime_rank)
|
||||
serialize_path = os.path.join(args.engine_dir, engine_name)
|
||||
|
||||
xtrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
with open(serialize_path, 'rb') as f:
|
||||
engine_buffer = f.read()
|
||||
decoder = xtrt_llm.runtime.GenerationSession(model_config,
|
||||
engine_buffer,
|
||||
runtime_mapping)
|
||||
|
||||
return decoder
|
||||
|
||||
|
||||
def main(args):
|
||||
runtime_rank = xtrt_llm.mpi_rank()
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
test_hf = args.test_hf and runtime_rank == 0 # only run hf on rank 0
|
||||
test_trt_llm = args.test_trt_llm
|
||||
model_dir = args.model_dir
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir,
|
||||
padding_side='left',
|
||||
model_max_length=2048,
|
||||
truncation=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset_cnn = load_dataset("ccdv/cnn_dailymail",
|
||||
'3.0.0',
|
||||
cache_dir=args.dataset_path)
|
||||
|
||||
config_path = os.path.join(args.engine_dir, 'config.json')
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
max_batch_size = args.batch_size
|
||||
|
||||
# runtime parameters
|
||||
# repetition_penalty = 1
|
||||
top_k = args.top_k
|
||||
output_len = args.output_len
|
||||
test_token_num = 923
|
||||
# top_p = 0.0
|
||||
# random_seed = 5
|
||||
temperature = 1
|
||||
num_beams = args.num_beams
|
||||
|
||||
pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
|
||||
end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0]
|
||||
|
||||
if test_trt_llm:
|
||||
xtrt_llm_gpt = TRTGPTJ(args, config)
|
||||
|
||||
if test_hf:
|
||||
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
||||
model.cuda()
|
||||
if args.data_type == 'fp16':
|
||||
model.half()
|
||||
|
||||
def summarize_xtrt_llm(datapoint):
|
||||
batch_size = len(datapoint['article'])
|
||||
|
||||
line = copy.copy(datapoint['article'])
|
||||
line_encoded = []
|
||||
input_lengths = []
|
||||
for i in range(batch_size):
|
||||
line[i] = line[i] + ' TL;DR: '
|
||||
|
||||
line[i] = line[i].strip()
|
||||
line[i] = line[i].replace(" n't", "n't")
|
||||
|
||||
input_id = tokenizer.encode(line[i],
|
||||
return_tensors='pt').type(torch.int32)
|
||||
input_id = input_id[:, -test_token_num:]
|
||||
|
||||
line_encoded.append(input_id)
|
||||
input_lengths.append(input_id.shape[-1])
|
||||
|
||||
# do padding, should move outside the profiling to prevent the overhead
|
||||
max_length = max(input_lengths)
|
||||
if xtrt_llm_gpt.remove_input_padding:
|
||||
line_encoded = [
|
||||
torch.tensor(t, dtype=torch.int32).cuda() for t in line_encoded
|
||||
]
|
||||
else:
|
||||
# do padding, should move outside the profiling to prevent the overhead
|
||||
for i in range(batch_size):
|
||||
pad_size = max_length - input_lengths[i]
|
||||
|
||||
pad = torch.ones([1, pad_size]).type(torch.int32) * pad_id
|
||||
line_encoded[i] = torch.cat(
|
||||
[torch.tensor(line_encoded[i], dtype=torch.int32), pad],
|
||||
axis=-1)
|
||||
|
||||
line_encoded = torch.cat(line_encoded, axis=0).cuda()
|
||||
input_lengths = torch.tensor(input_lengths,
|
||||
dtype=torch.int32).cuda()
|
||||
|
||||
sampling_config = xtrt_llm.runtime.SamplingConfig(
|
||||
end_id=end_id, pad_id=pad_id, top_k=top_k, num_beams=num_beams)
|
||||
|
||||
with torch.no_grad():
|
||||
xtrt_llm_gpt.setup(batch_size,
|
||||
max_context_length=max_length,
|
||||
max_new_tokens=output_len,
|
||||
beam_width=num_beams)
|
||||
|
||||
if xtrt_llm_gpt.remove_input_padding:
|
||||
output_ids = xtrt_llm_gpt.decode_batch(
|
||||
line_encoded, sampling_config)
|
||||
else:
|
||||
output_ids = xtrt_llm_gpt.decode(
|
||||
line_encoded,
|
||||
input_lengths,
|
||||
sampling_config,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Extract a list of tensors of shape beam_width x output_ids.
|
||||
output_beams_list, output_ids_list = [], []
|
||||
if xtrt_llm_gpt.mapping.is_first_pp_rank():
|
||||
output_beams_list = [
|
||||
tokenizer.batch_decode(output_ids[batch_idx, :,
|
||||
input_lengths[batch_idx]:],
|
||||
skip_special_tokens=True)
|
||||
for batch_idx in range(batch_size)
|
||||
]
|
||||
output_ids_list = [
|
||||
output_ids[batch_idx, :, input_lengths[batch_idx]:]
|
||||
for batch_idx in range(batch_size)
|
||||
]
|
||||
return output_beams_list, output_ids_list
|
||||
|
||||
def summarize_hf(datapoint):
|
||||
batch_size = len(datapoint['article'])
|
||||
if batch_size > 1:
|
||||
logger.warning(
|
||||
f"HF does not support batch_size > 1 to verify correctness due to padding. Current batch size is {batch_size}"
|
||||
)
|
||||
|
||||
line = copy.copy(datapoint['article'])
|
||||
for i in range(batch_size):
|
||||
line[i] = line[i] + ' TL;DR: '
|
||||
|
||||
line[i] = line[i].strip()
|
||||
line[i] = line[i].replace(" n't", "n't")
|
||||
|
||||
line_encoded = tokenizer(line,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True)["input_ids"].type(torch.int64)
|
||||
|
||||
line_encoded = line_encoded[:, -test_token_num:]
|
||||
line_encoded = line_encoded.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(line_encoded,
|
||||
max_length=len(line_encoded[0]) +
|
||||
output_len,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
num_beams=num_beams,
|
||||
num_return_sequences=num_beams,
|
||||
early_stopping=True)
|
||||
|
||||
tokens_list = output[:, len(line_encoded[0]):].tolist()
|
||||
output = output.reshape([batch_size, num_beams, -1])
|
||||
output_lines_list = [
|
||||
tokenizer.batch_decode(output[:, i, len(line_encoded[0]):],
|
||||
skip_special_tokens=True)
|
||||
for i in range(num_beams)
|
||||
]
|
||||
|
||||
return output_lines_list, tokens_list
|
||||
|
||||
if test_trt_llm:
|
||||
datapoint = dataset_cnn['test'][0:1]
|
||||
summary, _ = summarize_xtrt_llm(datapoint)
|
||||
if runtime_rank == 0:
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
logger.info("XTRT-LLM Generated : ")
|
||||
logger.info(f" Article : {datapoint['article']}")
|
||||
logger.info(f"\n Highlights : {datapoint['highlights']}")
|
||||
logger.info(f"\n Summary : {summary}")
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
|
||||
if test_hf:
|
||||
datapoint = dataset_cnn['test'][0:1]
|
||||
summary, _ = summarize_hf(datapoint)
|
||||
logger.info("---------------------------------------------------------")
|
||||
logger.info("HF Generated : ")
|
||||
logger.info(f" Article : {datapoint['article']}")
|
||||
logger.info(f"\n Highlights : {datapoint['highlights']}")
|
||||
logger.info(f"\n Summary : {summary}")
|
||||
logger.info("---------------------------------------------------------")
|
||||
|
||||
xtrt_llm_result = [[] for _ in range(num_beams)]
|
||||
hf_result = [[] for _ in range(num_beams)]
|
||||
ite_count = 0
|
||||
data_point_idx = 0
|
||||
|
||||
# Support running the set with different order to verify correctness
|
||||
test_idx = list(
|
||||
range(min(len(dataset_cnn['test']), max_batch_size * args.max_ite)))
|
||||
random.seed(args.random_seed)
|
||||
random.shuffle(test_idx)
|
||||
while (data_point_idx < len(dataset_cnn['test'])) and (ite_count <
|
||||
args.max_ite):
|
||||
if runtime_rank == 0:
|
||||
logger.debug(
|
||||
f"run data_point {data_point_idx} ~ {data_point_idx + max_batch_size}"
|
||||
)
|
||||
datapoint = dataset_cnn['test'][test_idx[data_point_idx:(
|
||||
data_point_idx + max_batch_size)]]
|
||||
|
||||
if test_trt_llm:
|
||||
profiler.start('xtrt_llm')
|
||||
summary_xtrt_llm, tokens_xtrt_llm = summarize_xtrt_llm(
|
||||
datapoint)
|
||||
profiler.stop('xtrt_llm')
|
||||
|
||||
if test_hf:
|
||||
profiler.start('hf')
|
||||
summary_hf, tokens_hf = summarize_hf(datapoint)
|
||||
profiler.stop('hf')
|
||||
|
||||
if runtime_rank == 0:
|
||||
if test_trt_llm:
|
||||
for batch_idx in range(len(summary_xtrt_llm)):
|
||||
for beam_idx in range(num_beams):
|
||||
xtrt_llm_result[beam_idx].append(
|
||||
tuple([
|
||||
datapoint['id'][batch_idx],
|
||||
summary_xtrt_llm[batch_idx][beam_idx],
|
||||
datapoint['highlights'][batch_idx]
|
||||
]))
|
||||
if test_hf:
|
||||
for beam_idx in range(num_beams):
|
||||
for batch_idx in range(len(summary_hf[beam_idx])):
|
||||
hf_result[beam_idx].append(
|
||||
tuple([
|
||||
datapoint['id'][batch_idx],
|
||||
summary_hf[beam_idx][batch_idx],
|
||||
datapoint['highlights'][batch_idx]
|
||||
]))
|
||||
|
||||
logger.debug('-' * 100)
|
||||
logger.debug(f"Article : {datapoint['article']}")
|
||||
if test_trt_llm:
|
||||
logger.debug(f'XTRT-LLM Summary: {summary_xtrt_llm}')
|
||||
if test_hf:
|
||||
logger.debug(f'HF Summary: {summary_hf}')
|
||||
logger.debug(f"highlights : {datapoint['highlights']}")
|
||||
|
||||
data_point_idx += max_batch_size
|
||||
ite_count += 1
|
||||
|
||||
if runtime_rank == 0:
|
||||
if test_trt_llm:
|
||||
np.random.seed(0) # rouge score use sampling to compute the score
|
||||
logger.info(
|
||||
f'XTRT-LLM (total latency: {profiler.elapsed_time_in_sec("xtrt_llm")} sec)'
|
||||
)
|
||||
for beam_idx in range(num_beams):
|
||||
# Because 'rouge' uses sampling to compute the scores, the scores
|
||||
# would be different when the results are same with different order.
|
||||
# So, sorting them first to prevent this issue.
|
||||
metric_xtrt_llm = load_metric("rouge")
|
||||
metric_xtrt_llm.seed = 0
|
||||
beams_results = sorted(xtrt_llm_result[beam_idx])
|
||||
|
||||
for j in range(len(beams_results)):
|
||||
metric_xtrt_llm.add_batch(
|
||||
predictions=[beams_results[j][1]],
|
||||
references=[beams_results[j][2]])
|
||||
|
||||
logger.info(f"XTRT-LLM beam {beam_idx} result")
|
||||
computed_metrics_xtrt_llm = metric_xtrt_llm.compute()
|
||||
for key in computed_metrics_xtrt_llm.keys():
|
||||
logger.info(
|
||||
f' {key} : {computed_metrics_xtrt_llm[key].mid[2]*100}'
|
||||
)
|
||||
|
||||
if args.check_accuracy and beam_idx == 0:
|
||||
assert computed_metrics_xtrt_llm['rouge1'].mid[
|
||||
2] * 100 > args.xtrt_llm_rouge1_threshold
|
||||
if test_hf:
|
||||
np.random.seed(0) # rouge score use sampling to compute the score
|
||||
logger.info(
|
||||
f'Hugging Face (total latency: {profiler.elapsed_time_in_sec("hf")} sec)'
|
||||
)
|
||||
for beam_idx in range(num_beams):
|
||||
metric_tensorrt_hf = load_metric("rouge")
|
||||
metric_tensorrt_hf.seed = 0
|
||||
beams_results = sorted(hf_result[beam_idx])
|
||||
|
||||
for j in range(len(beams_results)):
|
||||
metric_tensorrt_hf.add_batch(
|
||||
predictions=[beams_results[j][1]],
|
||||
references=[beams_results[j][2]])
|
||||
logger.info(f"HF beam {beam_idx} result")
|
||||
computed_metrics_hf = metric_tensorrt_hf.compute()
|
||||
for key in computed_metrics_hf.keys():
|
||||
logger.info(
|
||||
f' {key} : {computed_metrics_hf[key].mid[2]*100}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_dir', type=str, default='EleutherAI/gpt-j-6B')
|
||||
parser.add_argument('--test_hf', action='store_true')
|
||||
parser.add_argument('--test_trt_llm', action='store_true')
|
||||
parser.add_argument('--data_type',
|
||||
type=str,
|
||||
choices=['fp32', 'fp16'],
|
||||
default='fp32')
|
||||
parser.add_argument('--dataset_path', type=str, default='')
|
||||
parser.add_argument('--log_level', type=str, default='info')
|
||||
parser.add_argument('--engine_dir', type=str, default='gptj_engine')
|
||||
parser.add_argument('--batch_size', type=int, default=1)
|
||||
parser.add_argument('--max_ite', type=int, default=20)
|
||||
parser.add_argument('--output_len', type=int, default=100)
|
||||
parser.add_argument('--check_accuracy', action='store_true')
|
||||
parser.add_argument('--xtrt_llm_rouge1_threshold',
|
||||
type=float,
|
||||
default=15.0)
|
||||
parser.add_argument('--num_beams', type=int, default=1)
|
||||
parser.add_argument('--top_k', type=int, default=1)
|
||||
parser.add_argument('--random_seed', type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
14
examples/gptj/utils/__init__.py
Normal file
14
examples/gptj/utils/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
178
examples/gptj/utils/token_encoder.py
Normal file
178
examples/gptj/utils/token_encoder.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Byte pair encoding utilities"""
|
||||
|
||||
# Modified MIT License
|
||||
|
||||
# Software Copyright (c) 2019 OpenAI
|
||||
|
||||
# We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
|
||||
# We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
# associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
# subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included
|
||||
# in all copies or substantial portions of the Software.
|
||||
# The above copyright notice and this permission notice need not be included
|
||||
# with content created by the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
|
||||
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
|
||||
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
||||
# OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from functools import lru_cache
|
||||
|
||||
import regex as re
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a significant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"),
|
||||
ord("~") + 1)) + list(range(
|
||||
ord("¡"),
|
||||
ord("¬") + 1)) + list(range(ord("®"),
|
||||
ord("ÿ") + 1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class Encoder:
|
||||
|
||||
def __init__(self, encoder, bpe_merges, errors='replace'):
|
||||
self.encoder = encoder
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
|
||||
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
self.pat = re.compile(
|
||||
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(
|
||||
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i +
|
||||
1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token]
|
||||
for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c]
|
||||
for c in text]).decode('utf-8', errors=self.errors)
|
||||
return text
|
||||
|
||||
def batch_decode(self, output):
|
||||
ret = []
|
||||
for tokens in output:
|
||||
ret.append(self.decode(tokens))
|
||||
return ret
|
||||
|
||||
|
||||
def get_encoder(vocab_file, bpe_file):
|
||||
with open(vocab_file, 'r', encoding="utf-8") as f:
|
||||
encoder = json.load(f)
|
||||
with open(bpe_file, 'r', encoding="utf-8") as f:
|
||||
bpe_data = f.read()
|
||||
bpe_merges = [
|
||||
tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]
|
||||
]
|
||||
return Encoder(
|
||||
encoder=encoder,
|
||||
bpe_merges=bpe_merges,
|
||||
)
|
||||
455
examples/gptj/weight.py
Normal file
455
examples/gptj/weight.py
Normal file
@@ -0,0 +1,455 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
from operator import attrgetter
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import xtrt_llm
|
||||
import xtrt_llm.logger as logger
|
||||
from xtrt_llm._utils import str_dtype_to_torch
|
||||
from xtrt_llm.models import GPTJForCausalLM
|
||||
from xtrt_llm.models.quantized.quant import get_dummy_quant_scales
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
|
||||
def get_scaling_factors(
|
||||
model_path: Union[str, Path],
|
||||
num_layers: int,
|
||||
quant_mode: Optional[QuantMode] = None,
|
||||
) -> Optional[Dict[str, List[int]]]:
|
||||
""" Get the scaling factors for GPT-J model
|
||||
|
||||
Returns a dictionary of scaling factors for the selected layers of the
|
||||
GPT-J model.
|
||||
|
||||
Args:
|
||||
model_path (str): Path to the quantized GPT-J model
|
||||
layers (list): List of layers to get the scaling factors for. If None,
|
||||
all layers are selected.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary of scaling factors for the selected layers of the
|
||||
GPT-J model.
|
||||
|
||||
example:
|
||||
|
||||
{
|
||||
'qkv_act': qkv_act_scale,
|
||||
'qkv_weights': qkv_weights_scale,
|
||||
'qkv_output' : qkv_outputs_scale,
|
||||
'dense_act': dense_act_scale,
|
||||
'dense_weights': dense_weights_scale,
|
||||
'fc_act': fc_act_scale,
|
||||
'fc_weights': fc_weights_scale,
|
||||
'proj_act': proj_act_scale,
|
||||
'proj_weights': proj_weights_scale,
|
||||
}
|
||||
"""
|
||||
|
||||
if model_path is None:
|
||||
logger.warning(f"--quantized_fp8_model_path not specified. "
|
||||
f"Initialize quantization scales automatically.")
|
||||
return get_dummy_quant_scales(num_layers)
|
||||
weight_dict = np.load(model_path)
|
||||
|
||||
# yapf: disable
|
||||
scaling_factor = {
|
||||
'qkv_act': [],
|
||||
'qkv_weights': [],
|
||||
'qkv_output': [],
|
||||
'dense_act': [],
|
||||
'dense_weights': [],
|
||||
'fc_act': [],
|
||||
'fc_weights': [],
|
||||
'proj_act': [],
|
||||
'proj_weights': [],
|
||||
}
|
||||
|
||||
for layer in range(num_layers):
|
||||
scaling_factor['qkv_act'].append(max(
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
|
||||
))
|
||||
scaling_factor['qkv_weights'].append(max(
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
|
||||
))
|
||||
if quant_mode is not None and quant_mode.has_fp8_kv_cache():
|
||||
# Not calibrarting KV cache.
|
||||
scaling_factor['qkv_output'].append(1.0)
|
||||
scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
|
||||
scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
|
||||
scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
|
||||
scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
|
||||
scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
|
||||
scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
|
||||
# yapf: enable
|
||||
for k, v in scaling_factor.items():
|
||||
assert len(v) == num_layers, \
|
||||
f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
|
||||
|
||||
return scaling_factor
|
||||
|
||||
|
||||
def load_from_hf_gpt_j(xtrt_llm_gpt_j: GPTJForCausalLM,
|
||||
hf_gpt_j,
|
||||
dtype="float32",
|
||||
scaling_factors=None):
|
||||
|
||||
hf_model_gptj_block_names = [
|
||||
"ln_1.weight",
|
||||
"ln_1.bias",
|
||||
"mlp.fc_in.weight",
|
||||
"mlp.fc_in.bias",
|
||||
"mlp.fc_out.weight",
|
||||
"mlp.fc_out.bias",
|
||||
]
|
||||
|
||||
xtrt_llm_model_gptj_block_names = [
|
||||
"input_layernorm.weight",
|
||||
"input_layernorm.bias",
|
||||
"mlp.fc.weight",
|
||||
"mlp.fc.bias",
|
||||
"mlp.proj.weight",
|
||||
"mlp.proj.bias",
|
||||
]
|
||||
|
||||
quant_mode = getattr(xtrt_llm_gpt_j, 'quant_mode', QuantMode(0))
|
||||
|
||||
xtrt_llm.logger.info('Loading weights from HF GPT-J...')
|
||||
tik = time.time()
|
||||
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
hf_gpt_j_state_dict = hf_gpt_j.state_dict()
|
||||
|
||||
v = hf_gpt_j_state_dict.get('transformer.wte.weight')
|
||||
xtrt_llm_gpt_j.embedding.weight.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
n_layer = hf_gpt_j.config.n_layer
|
||||
|
||||
for layer_idx in range(n_layer):
|
||||
prefix = "transformer.h." + str(layer_idx) + "."
|
||||
for idx, hf_attr in enumerate(hf_model_gptj_block_names):
|
||||
v = hf_gpt_j_state_dict.get(prefix + hf_attr)
|
||||
layer = attrgetter(xtrt_llm_model_gptj_block_names[idx])(
|
||||
xtrt_llm_gpt_j.layers[layer_idx])
|
||||
if idx == 2 and scaling_factors:
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].mlp.fc.activation_scaling_factor.value = np.array(
|
||||
[scaling_factors['fc_act'][layer_idx]],
|
||||
dtype=np.float32)
|
||||
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].mlp.fc.weights_scaling_factor.value = np.array(
|
||||
[scaling_factors['fc_weights'][layer_idx]],
|
||||
dtype=np.float32)
|
||||
|
||||
elif idx == 4 and scaling_factors:
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].mlp.proj.activation_scaling_factor.value = np.array(
|
||||
[scaling_factors['proj_act'][layer_idx]],
|
||||
dtype=np.float32)
|
||||
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].mlp.proj.weights_scaling_factor.value = np.array(
|
||||
[scaling_factors['proj_weights'][layer_idx]],
|
||||
dtype=np.float32)
|
||||
setattr(layer, 'value', v.to(torch_dtype).cpu().numpy())
|
||||
|
||||
# Attention QKV Linear
|
||||
# concatenate the Q, K, V layers weights.
|
||||
q_weights = hf_gpt_j_state_dict.get(prefix + "attn.q_proj.weight")
|
||||
k_weights = hf_gpt_j_state_dict.get(prefix + "attn.k_proj.weight")
|
||||
v_weights = hf_gpt_j_state_dict.get(prefix + "attn.v_proj.weight")
|
||||
qkv_weights = torch.cat((q_weights, k_weights, v_weights))
|
||||
layer = attrgetter("attention.qkv.weight")(
|
||||
xtrt_llm_gpt_j.layers[layer_idx])
|
||||
setattr(layer, "value", qkv_weights.to(torch_dtype).cpu().numpy())
|
||||
if scaling_factors:
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].attention.qkv.activation_scaling_factor.value = np.array(
|
||||
[scaling_factors['qkv_act'][layer_idx]], dtype=np.float32)
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].attention.qkv.weights_scaling_factor.value = np.array(
|
||||
[scaling_factors['qkv_weights'][layer_idx]],
|
||||
dtype=np.float32)
|
||||
|
||||
if quant_mode.has_fp8_kv_cache():
|
||||
if scaling_factors:
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].attention.kv_orig_quant_scale.value = np.array(
|
||||
[scaling_factors['qkv_output'][layer_idx]],
|
||||
dtype=np.float32)
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].attention.kv_quant_orig_scale.value = np.array(
|
||||
[1.0 / scaling_factors['qkv_output'][layer_idx]],
|
||||
dtype=np.float32)
|
||||
|
||||
# Attention Dense (out_proj) Linear
|
||||
v = hf_gpt_j_state_dict.get(prefix + "attn.out_proj.weight")
|
||||
layer = attrgetter("attention.dense.weight")(
|
||||
xtrt_llm_gpt_j.layers[layer_idx])
|
||||
setattr(layer, "value", v.to(torch_dtype).cpu().numpy())
|
||||
if scaling_factors:
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].attention.dense.activation_scaling_factor.value = np.array(
|
||||
[scaling_factors['dense_act'][layer_idx]], dtype=np.float32)
|
||||
xtrt_llm_gpt_j.layers[
|
||||
layer_idx].attention.dense.weights_scaling_factor.value = np.array(
|
||||
[scaling_factors['dense_weights'][layer_idx]],
|
||||
dtype=np.float32)
|
||||
|
||||
v = hf_gpt_j_state_dict.get('transformer.ln_f.weight')
|
||||
xtrt_llm_gpt_j.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
v = hf_gpt_j_state_dict.get('transformer.ln_f.bias')
|
||||
xtrt_llm_gpt_j.ln_f.bias.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
v = hf_gpt_j_state_dict.get('lm_head.weight')
|
||||
xtrt_llm_gpt_j.lm_head.weight.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
v = hf_gpt_j_state_dict.get('lm_head.bias')
|
||||
xtrt_llm_gpt_j.lm_head.bias.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|
||||
|
||||
|
||||
def AWQ_quantize_pack_preprocess(weight, scale, group_size, packer,
|
||||
preprocessor):
|
||||
scale = scale.repeat_interleave(group_size, dim=0)
|
||||
weight = weight / scale
|
||||
weight = torch.round(weight).char()
|
||||
weight = torch.where(weight > 7, 7, weight)
|
||||
qweight_int8 = torch.where(weight < -8, -8, weight)
|
||||
int4_weight = packer(qweight_int8.cpu())
|
||||
int4_weight = preprocessor(int4_weight, torch.quint4x2)
|
||||
return int4_weight.view(torch.float32).cpu().numpy()
|
||||
|
||||
|
||||
def process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
||||
preprocessor, torch_dtype):
|
||||
weight = awq_gpt_j[mPrefix + ".weight"].T.contiguous()
|
||||
[k, n] = weight.shape
|
||||
amax = awq_gpt_j[mPrefix + ".weight_quantizer._amax"].reshape(
|
||||
(n, int(k / group_size))).T.contiguous()
|
||||
pre_quant_scale = awq_gpt_j[mPrefix +
|
||||
".input_quantizer._pre_quant_scale"].reshape(
|
||||
(1, k))
|
||||
scale = amax / 8.0
|
||||
mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale, group_size,
|
||||
packer, preprocessor)
|
||||
mOp.scale.value = scale.to(torch_dtype).cpu().numpy()
|
||||
mOp.pre_quant_scale.value = pre_quant_scale.to(torch_dtype).cpu().numpy()
|
||||
|
||||
|
||||
def deSmooth(weight, pre_quant_scale):
|
||||
[k, n] = weight.shape
|
||||
pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1,
|
||||
0).contiguous()
|
||||
weight = weight * pre_quant_scale
|
||||
return weight
|
||||
|
||||
|
||||
def reSmooth(weight, pre_quant_scale):
|
||||
[k, n] = weight.shape
|
||||
pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1,
|
||||
0).contiguous()
|
||||
weight = weight / pre_quant_scale
|
||||
return weight
|
||||
|
||||
|
||||
def get_scale(weight, group_size):
|
||||
weight = weight.T.contiguous()
|
||||
[n, k] = weight.shape
|
||||
weight = weight.reshape(n, int(k / group_size), group_size)
|
||||
weight = torch.abs(weight.reshape(-1, group_size))
|
||||
amax, idx = weight.max(1)
|
||||
amax = amax.reshape(n, int(k / group_size)).T.contiguous()
|
||||
return amax / 8
|
||||
|
||||
|
||||
def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale,
|
||||
group_size):
|
||||
weight = deSmooth(weight, pre_quant_scale)
|
||||
weight = reSmooth(weight, avg_pre_quant_scale)
|
||||
scale = get_scale(weight, group_size)
|
||||
return weight, scale
|
||||
|
||||
|
||||
def process_and_assign_qkv_weight(awq_gpt_j, prefix, mOp, group_size, packer,
|
||||
preprocessor, torch_dtype):
|
||||
q_weight = awq_gpt_j[prefix + "attn.q_proj.weight"].T.contiguous()
|
||||
k_weight = awq_gpt_j[prefix + "attn.k_proj.weight"].T.contiguous()
|
||||
v_weight = awq_gpt_j[prefix + "attn.v_proj.weight"].T.contiguous()
|
||||
[k, n] = q_weight.shape
|
||||
|
||||
q_pre_quant_scale = awq_gpt_j[
|
||||
prefix + "attn.q_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
|
||||
k_pre_quant_scale = awq_gpt_j[
|
||||
prefix + "attn.k_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
|
||||
v_pre_quant_scale = awq_gpt_j[
|
||||
prefix + "attn.v_proj.input_quantizer._pre_quant_scale"].reshape((1, k))
|
||||
|
||||
qkv_pre_quant_scale = (q_pre_quant_scale + k_pre_quant_scale +
|
||||
v_pre_quant_scale) / 3.0
|
||||
q_weight, q_scale = reSmooth_and_get_scale(q_weight, q_pre_quant_scale,
|
||||
qkv_pre_quant_scale, group_size)
|
||||
k_weight, k_scale = reSmooth_and_get_scale(k_weight, k_pre_quant_scale,
|
||||
qkv_pre_quant_scale, group_size)
|
||||
v_weight, v_scale = reSmooth_and_get_scale(v_weight, v_pre_quant_scale,
|
||||
qkv_pre_quant_scale, group_size)
|
||||
|
||||
qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1)
|
||||
qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1)
|
||||
mOp.pre_quant_scale.value = qkv_pre_quant_scale.to(
|
||||
torch_dtype).cpu().numpy()
|
||||
mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale,
|
||||
group_size, packer,
|
||||
preprocessor)
|
||||
mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy()
|
||||
|
||||
|
||||
def load_from_awq_gpt_j(xtrt_llm_gpt_j: GPTJForCausalLM,
|
||||
awq_gpt_j,
|
||||
config,
|
||||
dtype="float16",
|
||||
group_size=128):
|
||||
|
||||
awq_gptj_block_names = [
|
||||
"ln_1.weight",
|
||||
"ln_1.bias",
|
||||
"mlp.fc_in.bias",
|
||||
"mlp.fc_out.bias",
|
||||
]
|
||||
|
||||
xtrt_llm_model_gptj_block_names = [
|
||||
"input_layernorm.weight",
|
||||
"input_layernorm.bias",
|
||||
"mlp.fc.bias",
|
||||
"mlp.proj.bias",
|
||||
]
|
||||
|
||||
getattr(xtrt_llm_gpt_j, 'quant_mode', QuantMode(0))
|
||||
|
||||
packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
|
||||
preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
|
||||
|
||||
xtrt_llm.logger.info('Loading weights from AWQ GPT-J...')
|
||||
tik = time.time()
|
||||
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
|
||||
#check if we need to pad vocab
|
||||
v = awq_gpt_j.get('transformer.wte.weight')
|
||||
[vocab_size, k] = v.shape
|
||||
pad_vocab = False
|
||||
pad_vocab_size = vocab_size
|
||||
if vocab_size % 64 != 0:
|
||||
pad_vocab = True
|
||||
pad_vocab_size = int((vocab_size + 63) / 64) * 64
|
||||
if pad_vocab:
|
||||
new_v = torch.zeros([pad_vocab_size, k])
|
||||
new_v[:vocab_size, :] = v
|
||||
v = new_v
|
||||
xtrt_llm_gpt_j.embedding.weight.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
n_layer = config["n_layer"]
|
||||
|
||||
for layer_idx in range(n_layer):
|
||||
prefix = "transformer.h." + str(layer_idx) + "."
|
||||
xtrt_llm.logger.info(f'Process weights in layer: {layer_idx}')
|
||||
for idx, awq_attr in enumerate(awq_gptj_block_names):
|
||||
v = awq_gpt_j[prefix + awq_attr]
|
||||
layer = attrgetter(xtrt_llm_model_gptj_block_names[idx])(
|
||||
xtrt_llm_gpt_j.layers[layer_idx])
|
||||
setattr(layer, 'value', v.to(torch_dtype).cpu().numpy())
|
||||
|
||||
# Attention QKV Linear
|
||||
# concatenate the Q, K, V layers weights.
|
||||
process_and_assign_qkv_weight(
|
||||
awq_gpt_j, prefix,
|
||||
xtrt_llm_gpt_j.layers[layer_idx].attention.qkv, group_size,
|
||||
packer, preprocessor, torch_dtype)
|
||||
|
||||
# Attention Dense (out_proj) Linear
|
||||
mPrefix = prefix + "attn.out_proj"
|
||||
mOp = xtrt_llm_gpt_j.layers[layer_idx].attention.dense
|
||||
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
||||
preprocessor, torch_dtype)
|
||||
|
||||
# MLP Dense (mlp.fc) Linear
|
||||
mPrefix = prefix + "mlp.fc_in"
|
||||
mOp = xtrt_llm_gpt_j.layers[layer_idx].mlp.fc
|
||||
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
||||
preprocessor, torch_dtype)
|
||||
|
||||
# MLP Desne (mlp.proj) Linear
|
||||
mPrefix = prefix + "mlp.fc_out"
|
||||
mOp = xtrt_llm_gpt_j.layers[layer_idx].mlp.proj
|
||||
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
||||
preprocessor, torch_dtype)
|
||||
|
||||
v = awq_gpt_j['transformer.ln_f.weight']
|
||||
xtrt_llm_gpt_j.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
v = awq_gpt_j['transformer.ln_f.bias']
|
||||
xtrt_llm_gpt_j.ln_f.bias.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
#lm_head
|
||||
if pad_vocab:
|
||||
weight = awq_gpt_j['lm_head.weight']
|
||||
[vocab_size, k] = weight.shape
|
||||
new_weight = torch.zeros([pad_vocab_size, k])
|
||||
new_weight[:vocab_size, :] = weight
|
||||
new_weight = new_weight.T.contiguous()
|
||||
amax = awq_gpt_j['lm_head.weight_quantizer._amax'].reshape(
|
||||
[vocab_size, int(k / group_size)])
|
||||
new_amax = torch.ones([pad_vocab_size, int(k / group_size)])
|
||||
new_amax[:vocab_size, :] = amax
|
||||
new_amax = new_amax.T.contiguous()
|
||||
new_scale = new_amax / 8
|
||||
xtrt_llm_gpt_j.lm_head.qweight.value = AWQ_quantize_pack_preprocess(
|
||||
new_weight, new_scale, group_size, packer, preprocessor)
|
||||
xtrt_llm_gpt_j.lm_head.scale.value = new_scale.to(
|
||||
torch_dtype).cpu().numpy()
|
||||
xtrt_llm_gpt_j.lm_head.pre_quant_scale.value = awq_gpt_j[
|
||||
'lm_head.input_quantizer._pre_quant_scale'].to(
|
||||
torch_dtype).cpu().numpy()
|
||||
|
||||
bias = awq_gpt_j['lm_head.bias']
|
||||
new_bias = torch.zeros([pad_vocab_size])
|
||||
new_bias[:vocab_size] = bias
|
||||
xtrt_llm_gpt_j.lm_head.bias.value = new_bias.to(
|
||||
torch_dtype).cpu().numpy()
|
||||
else:
|
||||
mPrefix = "lm_head"
|
||||
mOp = xtrt_llm_gpt_j.lm_head
|
||||
process_and_assign_weight(awq_gpt_j, mPrefix, mOp, group_size, packer,
|
||||
preprocessor, torch_dtype)
|
||||
|
||||
v = awq_gpt_j['lm_head.bias']
|
||||
xtrt_llm_gpt_j.lm_head.bias.value = v.to(torch_dtype).cpu().numpy()
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|
||||
Reference in New Issue
Block a user