add pkgs
This commit is contained in:
2
examples/bloom/.gitignore
vendored
Normal file
2
examples/bloom/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
__pycache__/
|
||||
bloom/
|
||||
131
examples/bloom/README.md
Normal file
131
examples/bloom/README.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# BLOOM
|
||||
|
||||
This document shows how to build and run a BLOOM model in XTRT-LLM on both single XPU and single node multi-XPU.
|
||||
|
||||
## Overview
|
||||
|
||||
The XTRT-LLM BLOOM example code is located in [`examples/bloom`](./). There are several main files in that folder:
|
||||
|
||||
* [`build.py`](./build.py) to build the XTRT engine(s) needed to run the BLOOM model,
|
||||
* [`run.py`](./run.py) to run the inference on an input text,
|
||||
* [`summarize.py`](./summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset using the model.
|
||||
|
||||
## Support Matrix
|
||||
* FP16
|
||||
* INT8 & INT4 Weight-Only
|
||||
* Tensor Parallel
|
||||
|
||||
## Usage
|
||||
|
||||
The XTRT-LLM BLOOM example code locates at [examples/bloom](./). It takes HF weights as input, and builds the corresponding XTRT engines. The number of XTRT engines depends on the number of XPUs used to run inference.
|
||||
|
||||
### Build XTRT engine(s)
|
||||
|
||||
Need to prepare the HF BLOOM checkpoint first by following the guides here https://huggingface.co/docs/transformers/main/en/model_doc/bloom.
|
||||
|
||||
e.g. To install BLOOM-560M
|
||||
|
||||
```bash
|
||||
# Setup git-lfs
|
||||
git lfs install
|
||||
rm -rf ./downloads/bloom/560M/
|
||||
mkdir -p ./downloads/bloom/560M/ && git clone https://huggingface.co/bigscience/bloom-560m ./downloads/bloom/560M/
|
||||
```
|
||||
|
||||
XTRT-LLM BLOOM builds XTRT engine(s) from HF checkpoint.
|
||||
|
||||
Normally `build.py` only requires single XPU, but if you've already got all the XPUs needed for inference, you could enable parallel building to make the engine building process faster by adding `--parallel_build` argument. Please note that currently `parallel_build` feature only supports single node.
|
||||
|
||||
Here're some examples:
|
||||
|
||||
```bash
|
||||
# Build a single-XPU float16 engine from HF weights.
|
||||
# Try use_gemm_plugin to prevent accuracy issue. TODO check this holds for BLOOM
|
||||
|
||||
# Single XPU on BLOOM 560M
|
||||
python build.py --model_dir ./downloads/bloom/560M/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/bloom/560M/trt_engines/fp16/1-XPU/
|
||||
|
||||
# Build the BLOOM 560M using a single XPU and apply INT8 weight-only quantization.
|
||||
python build.py --model_dir ./downloads/bloom/560M/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_weight_only \
|
||||
--weight_only_precision int8 \
|
||||
--output_dir ./downloads/bloom/560M/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
# Use 2-way tensor parallelism on BLOOM 560M
|
||||
python build.py --model_dir ./downloads/bloom/560M/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/bloom/560M/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2
|
||||
```
|
||||
|
||||
#### SmoothQuant
|
||||
|
||||
Unlike the FP16 build where the HF weights are processed and loaded into the XTRT-LLM directly, the SmoothQuant needs to load INT8 weights which should be pre-processed before building an engine.
|
||||
|
||||
Example:
|
||||
```bash
|
||||
python3 hf_bloom_convert.py -i ./downloads/bloom/560M/ -o ./downloads/bloom-smooth/560M --smoothquant 0.5 --tensor-parallelism 1 --storage-type float16
|
||||
```
|
||||
Note `hf_bloom_convert.py` run with pytorch, and
|
||||
1. `torch-cpu` has better accuracy than XPyTorch generally.
|
||||
2. XPyTorch often use more than 32GB GM, thus more XPU are necessary to finish it.
|
||||
3. add `-p=1` if run with XPyTorch.
|
||||
|
||||
[`build.py`](./build.py) add new options for the support of INT8 inference of SmoothQuant models.
|
||||
|
||||
`--use_smooth_quant` is the starting point of INT8 inference. By default, it
|
||||
will run the model in the _per-tensor_ mode.
|
||||
|
||||
`--per-token` and `--per-channel` are not supported yet.
|
||||
|
||||
Examples of build invocations:
|
||||
|
||||
```bash
|
||||
# Build model for SmoothQuant in the _per_tensor_ mode.
|
||||
python3 build.py --bin_model_dir=./downloads/bloom-smooth/560M/1-XPU \
|
||||
--use_smooth_quant \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/bloom-smooth/560M/trt_engines/fp16/1-XPU/
|
||||
```
|
||||
|
||||
Note that GPT attention plugin is required to be enabled for SmoothQuant for now.
|
||||
|
||||
|
||||
Note we use `--bin_model_dir` instead of `--model_dir` since SmoothQuant model needs INT8 weights and various scales from the binary files.
|
||||
|
||||
### Run
|
||||
|
||||
```bash
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./downloads/bloom/560M/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/fp16/1-XPU/
|
||||
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./downloads/bloom/560M/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
python run.py --tokenizer_dir ./downloads/bloom/560M/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/fp16/1-XPU/
|
||||
|
||||
python run.py --tokenizer_dir ./downloads/bloom/560M/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
python run.py --tokenizer_dir ./downloads/bloom/560M/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/bloom-smooth/560M/trt_engines/fp16/1-XPU/
|
||||
|
||||
mpirun -n 2 --allow-run-as-root \
|
||||
python run.py --tokenizer_dir ./downloads/bloom/560M/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/fp16/2-XPU/
|
||||
```
|
||||
132
examples/bloom/README_CN.md
Normal file
132
examples/bloom/README_CN.md
Normal file
@@ -0,0 +1,132 @@
|
||||
# BLOOM
|
||||
|
||||
本文档介绍了如何使用昆仑芯XTRT-LLM在单XPU和单节点多XPU上使用昆仑芯XTRT-LLM构建和运行BLOOM模型。
|
||||
|
||||
## 概述
|
||||
|
||||
XTRT-LLM BLOOM示例代码位于 [`examples/bloom`](./). 此文件夹中有以下几个主要文件:
|
||||
|
||||
* [`build.py`](./build.py) 构建运行BLOOM模型所需的XTRT引擎
|
||||
* [`run.py`](./run.py) 基于输入的文字进行推理
|
||||
* [`summarize.py`](./summarize.py) 使用此模型对[cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) 数据集中的文章进行总结
|
||||
|
||||
## 支持的矩阵
|
||||
|
||||
* FP16
|
||||
* INT8 Weight-Only
|
||||
* Tensor Parallel
|
||||
|
||||
## 使用说明
|
||||
|
||||
XTRT-LLM BLOOM示例代码位于[examples/bloom](./)。它使用HF权重作为输入,并且构建对应的XTRT引擎。XTRT引擎的数量取决于为了运行推理而是用的XPU个数。
|
||||
|
||||
### 构建XTRT引擎
|
||||
|
||||
需要先按照下面的指南准备HF BLOOM checkpoint:https://huggingface.co/docs/transformers/main/en/model_doc/bloom。
|
||||
|
||||
举例:安装BLOOM-560M
|
||||
|
||||
```bash
|
||||
# Setup git-lfs
|
||||
git lfs install
|
||||
rm -rf ./downloads/bloom/560M/
|
||||
mkdir -p ./downloads/bloom/560M/ && git clone https://huggingface.co/bigscience/bloom-560m ./downloads/bloom/560M/
|
||||
```
|
||||
|
||||
XTRT-LLM BLOOM从HF checkpoint构建XTRT引擎。
|
||||
|
||||
通常 `build.py`只需要单个XPU,但如果您已经获得了推理所需的所有XPU,则可以通过添加 `--parallel_build` 参数来启用并行构建,从而加快引擎构建过程。请注意,目前`parallel_build`仅支持单个节点XPU。
|
||||
|
||||
以下为示例:
|
||||
|
||||
```bash
|
||||
# Build a single-XPU float16 engine from HF weights.
|
||||
# Try use_gemm_plugin to prevent accuracy issue. TODO check this holds for BLOOM
|
||||
|
||||
# Single XPU on BLOOM 560M
|
||||
python build.py --model_dir ./downloads/bloom/560M/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/bloom/560M/trt_engines/fp16/1-XPU/
|
||||
|
||||
# Build the BLOOM 560M using a single XPU and apply INT8 weight-only quantization.
|
||||
python build.py --model_dir ./downloads/bloom/560M/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_weight_only \
|
||||
--weight_only_precision int8 \
|
||||
--output_dir ./downloads/bloom/560M/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
# Use 2-way tensor parallelism on BLOOM 560M
|
||||
python build.py --model_dir ./downloads/bloom/560M/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/bloom/560M/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2
|
||||
```
|
||||
|
||||
#### SmoothQuant
|
||||
|
||||
|
||||
与FP16的HF权重可以直接被处理并加载到XTRT-LLM不同,SmoothQuant需要加载INT8权重,而INT8权重在构建引擎之前需要进行预处理。
|
||||
|
||||
示例:
|
||||
```bash
|
||||
python3 hf_bloom_convert.py -i ./downloads/bloom/560M/ -o ./downloads/bloom-smooth/560M --smoothquant 0.5 --tensor-parallelism 1 --storage-type float16
|
||||
```
|
||||
|
||||
注意:使用PyTorch运行`hf_bloom_convert.py`,并且
|
||||
|
||||
1. 'torch-cpu' 通常比XPyTorch精度更高
|
||||
2. XPyTorch 通常使用超过32GB的GM,因此需要更多的XPU来完成它
|
||||
3. 使用XPyTorch运行时,请添加`-p=1`
|
||||
|
||||
`build.py`增加了新的选项来支持SmoothQuant模型的INT8推理。
|
||||
|
||||
`--use_smooth_quant` 是INT8推理的起点。默认情况下,它将以`--per-token`模式运行模型。
|
||||
`--per-token`和`--per-channel`目前还不支持。
|
||||
|
||||
构建调用示例:
|
||||
|
||||
```bash
|
||||
# Build model for SmoothQuant in the _per_tensor_ mode.
|
||||
python3 build.py --bin_model_dir=./downloads/bloom-smooth/560M/1-XPU \
|
||||
--use_smooth_quant \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/bloom-smooth/560M/trt_engines/fp16/1-XPU/
|
||||
```
|
||||
|
||||
注意:目前,SmoothQuant需要启用GPT attention插件。
|
||||
|
||||
注意:我们使用`--bin_model_dir`而不是`--model_dir`,因为SmoothQuant模型需要INT8权重和二进制文件中的各种scales。
|
||||
|
||||
### 运行
|
||||
|
||||
```bash
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./downloads/bloom/560M/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/fp16/1-XPU/
|
||||
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--hf_model_dir ./downloads/bloom/560M/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
python run.py --tokenizer_dir ./downloads/bloom/560M/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/fp16/1-XPU/
|
||||
|
||||
python run.py --tokenizer_dir ./downloads/bloom/560M/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
python run.py --tokenizer_dir ./downloads/bloom/560M/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/bloom-smooth/560M/trt_engines/fp16/1-XPU/
|
||||
|
||||
mpirun -n 2 --allow-run-as-root \
|
||||
python run.py --tokenizer_dir ./downloads/bloom/560M/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/bloom/560M/trt_engines/fp16/2-XPU/
|
||||
```
|
||||
521
examples/bloom/build.py
Normal file
521
examples/bloom/build.py
Normal file
@@ -0,0 +1,521 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from transformers import BloomConfig, BloomForCausalLM
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm._utils import str_dtype_to_xtrt
|
||||
from xtrt_llm.builder import Builder
|
||||
from xtrt_llm.logger import logger
|
||||
from xtrt_llm.mapping import Mapping
|
||||
from xtrt_llm.models import smooth_quantize, weight_only_quantize
|
||||
from xtrt_llm.network import net_guard
|
||||
from xtrt_llm.plugin.plugin import ContextFMHAType
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
from weight import load_from_hf_bloom, load_from_bin, parse_config, check_embedding_share # isort:skip
|
||||
|
||||
MODEL_NAME = "bloom"
|
||||
|
||||
import onnx
|
||||
import tvm.tensorrt as trt
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
|
||||
def trt_dtype_to_onnx(dtype):
|
||||
if dtype == trt.float16:
|
||||
return TensorProto.DataType.FLOAT16
|
||||
elif dtype == trt.float32:
|
||||
return TensorProto.DataType.FLOAT
|
||||
elif dtype == trt.int32:
|
||||
return TensorProto.DataType.INT32
|
||||
else:
|
||||
raise TypeError("%s is not supported" % dtype)
|
||||
|
||||
|
||||
def to_onnx(network, path):
|
||||
inputs = []
|
||||
for i in range(network.num_inputs):
|
||||
network_input = network.get_input(i)
|
||||
inputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
network_input.name, trt_dtype_to_onnx(network_input.dtype),
|
||||
list(network_input.shape)))
|
||||
|
||||
outputs = []
|
||||
for i in range(network.num_outputs):
|
||||
network_output = network.get_output(i)
|
||||
outputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
network_output.name, trt_dtype_to_onnx(network_output.dtype),
|
||||
list(network_output.shape)))
|
||||
|
||||
nodes = []
|
||||
for i in range(network.num_layers):
|
||||
layer = network.get_layer(i)
|
||||
layer_inputs = []
|
||||
for j in range(layer.num_inputs):
|
||||
ipt = layer.get_input(j)
|
||||
if ipt is not None:
|
||||
layer_inputs.append(layer.get_input(j).name)
|
||||
layer_outputs = [
|
||||
layer.get_output(j).name for j in range(layer.num_outputs)
|
||||
]
|
||||
nodes.append(
|
||||
helper.make_node(str(layer.type),
|
||||
name=layer.name,
|
||||
inputs=layer_inputs,
|
||||
outputs=layer_outputs,
|
||||
domain="com.nvidia"))
|
||||
|
||||
onnx_model = helper.make_model(helper.make_graph(nodes,
|
||||
'attention',
|
||||
inputs,
|
||||
outputs,
|
||||
initializer=None),
|
||||
producer_name='NVIDIA')
|
||||
onnx.save(onnx_model, path)
|
||||
|
||||
|
||||
def get_engine_name(model, dtype, tp_size, rank):
|
||||
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
|
||||
|
||||
|
||||
def serialize_engine(engine, path):
|
||||
logger.info(f'Serializing engine to {path}...')
|
||||
tik = time.time()
|
||||
engine.serialize(path)
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Engine serialized. Total time: {t}')
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--world_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='world size, only support tensor parallelism now')
|
||||
parser.add_argument('--model_dir', type=str, default=None)
|
||||
parser.add_argument('--bin_model_dir', type=str, default=None)
|
||||
parser.add_argument('--dtype',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float32', 'float16'])
|
||||
parser.add_argument(
|
||||
'--timing_cache',
|
||||
type=str,
|
||||
default='model.cache',
|
||||
help=
|
||||
'The path of to read timing cache from, will be ignored if the file does not exist'
|
||||
)
|
||||
parser.add_argument('--log_level', type=str, default='info')
|
||||
parser.add_argument('--vocab_size', type=int, default=250680)
|
||||
parser.add_argument('--n_layer', type=int, default=32)
|
||||
parser.add_argument('--n_positions', type=int, default=2048)
|
||||
parser.add_argument('--n_embd', type=int, default=4096)
|
||||
parser.add_argument('--n_head', type=int, default=32)
|
||||
parser.add_argument('--mlp_hidden_size', type=int, default=None)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--max_input_len', type=int, default=1024)
|
||||
parser.add_argument('--max_output_len', type=int, default=1024)
|
||||
parser.add_argument('--max_beam_width', type=int, default=1)
|
||||
parser.add_argument('--use_gpt_attention_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
choices=['float16', 'float32'])
|
||||
parser.add_argument('--use_gemm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
choices=['float16', 'float32'])
|
||||
parser.add_argument('--enable_context_fmha',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--enable_context_fmha_fp32_acc',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument(
|
||||
'--use_layernorm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
choices=['float16', 'float32'],
|
||||
help=
|
||||
"Activates layernorm plugin. You can specify the plugin dtype or leave blank to use the model dtype."
|
||||
)
|
||||
parser.add_argument('--parallel_build', default=False, action='store_true')
|
||||
parser.add_argument('--visualize', default=False, action='store_true')
|
||||
parser.add_argument('--enable_debug_output',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--gpus_per_node', type=int, default=8)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='bloom_outputs',
|
||||
help=
|
||||
'The path to save the serialized engine files, timing cache file and model configs'
|
||||
)
|
||||
# Arguments related to the quantization of the model.
|
||||
parser.add_argument(
|
||||
'--use_smooth_quant',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.'
|
||||
'See --per_channel and --per_token for finer-grained quantization options.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_weight_only',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
||||
'See --weight_only_precision to set the precision')
|
||||
parser.add_argument(
|
||||
'--weight_only_precision',
|
||||
const='int8',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default='int8',
|
||||
choices=['int8', 'int4'],
|
||||
help=
|
||||
'Define the precision for the weights when using weight-only quantization.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_channel',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor for the GEMM\'s result. '
|
||||
'per_channel instead uses a different static scaling factor for each channel. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--per_token',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale activations in the int8 range. '
|
||||
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--int8_kv_cache',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_parallel_embedding',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--embedding_sharding_dim',
|
||||
type=int,
|
||||
default=0,
|
||||
choices=[0, 1],
|
||||
help=
|
||||
'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
|
||||
'To shard it along hidden dimension, set embedding_sharding_dim=1'
|
||||
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_embedding_sharing',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
'Try to reduce the engine size by sharing the embedding lookup table between two layers.'
|
||||
'Note: the flag might not take effect when the criteria are not met.')
|
||||
parser.add_argument(
|
||||
'--use_lookup_plugin',
|
||||
nargs='?',
|
||||
const=None,
|
||||
default=False,
|
||||
choices=['float16', 'float32', 'bfloat16'],
|
||||
help="Activates the lookup plugin which enables embedding sharing.")
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
if args.model_dir is not None:
|
||||
hf_config = BloomConfig.from_pretrained(args.model_dir)
|
||||
args.n_embd = hf_config.hidden_size
|
||||
args.n_head = hf_config.num_attention_heads
|
||||
args.n_layer = hf_config.num_hidden_layers
|
||||
args.vocab_size = hf_config.vocab_size
|
||||
elif args.bin_model_dir is not None:
|
||||
logger.info(f"Setting model configuration from {args.bin_model_dir}.")
|
||||
n_embd, n_head, n_layer, vocab_size, _, rotary_pct, bias, inter_size, multi_query_mode, dtype, prompt_num_tasks, prompt_max_vocab_size = parse_config(
|
||||
Path(args.bin_model_dir) / "config.ini")
|
||||
args.n_embd = n_embd
|
||||
args.n_head = n_head
|
||||
args.n_layer = n_layer
|
||||
args.vocab_size = vocab_size
|
||||
|
||||
assert not (
|
||||
args.use_smooth_quant and args.use_weight_only
|
||||
), "You cannot enable both SmoothQuant and INT8 weight-only together."
|
||||
|
||||
if args.use_smooth_quant:
|
||||
args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
|
||||
args.per_channel)
|
||||
elif args.use_weight_only:
|
||||
args.quant_mode = QuantMode.use_weight_only(
|
||||
args.weight_only_precision == 'int4')
|
||||
else:
|
||||
args.quant_mode = QuantMode(0)
|
||||
|
||||
if args.int8_kv_cache:
|
||||
args.quant_mode = args.quant_mode.set_int8_kv_cache()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def build_rank_engine(builder: Builder,
|
||||
builder_config: xtrt_llm.builder.BuilderConfig,
|
||||
engine_name, rank, args):
|
||||
'''
|
||||
@brief: Build the engine on the given rank.
|
||||
@param rank: The rank to build the engine.
|
||||
@param args: The cmd line arguments.
|
||||
@return: The built engine.
|
||||
'''
|
||||
kv_dtype = str_dtype_to_xtrt(args.dtype)
|
||||
|
||||
# Share_embedding_table can be set True only when:
|
||||
# 1) the weight for lm_head() does not exist while other weights exist
|
||||
# 2) For multiple-processes, use_parallel_embedding=True and embedding_sharding_dim == 0.
|
||||
# Besides, for TensorRT 9.0, we can observe the engine size reduction when the lookup and gemm plugin are enabled.
|
||||
share_embedding_table = False
|
||||
if args.use_embedding_sharing:
|
||||
if args.world_size > 1:
|
||||
if args.model_dir is not None and args.embedding_sharding_dim == 0 and args.use_parallel_embedding:
|
||||
share_embedding_table = check_embedding_share(args.model_dir)
|
||||
else:
|
||||
if args.model_dir is not None:
|
||||
share_embedding_table = check_embedding_share(args.model_dir)
|
||||
|
||||
if not share_embedding_table:
|
||||
logger.warning(f'Cannot share the embedding lookup table.')
|
||||
|
||||
if share_embedding_table:
|
||||
logger.info(
|
||||
'Engine will share embedding and language modeling weights.')
|
||||
|
||||
# Initialize Module
|
||||
xtrt_llm_bloom = xtrt_llm.models.BloomForCausalLM(
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
hidden_size=args.n_embd,
|
||||
vocab_size=args.vocab_size,
|
||||
max_position_embeddings=args.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mapping=Mapping(world_size=args.world_size,
|
||||
rank=rank,
|
||||
tp_size=args.world_size), # TP only
|
||||
use_parallel_embedding=args.use_parallel_embedding,
|
||||
embedding_sharding_dim=args.embedding_sharding_dim,
|
||||
share_embedding_table=share_embedding_table,
|
||||
quant_mode=args.quant_mode)
|
||||
if args.use_smooth_quant:
|
||||
xtrt_llm_bloom = smooth_quantize(xtrt_llm_bloom, args.quant_mode)
|
||||
elif args.use_weight_only and 0:
|
||||
xtrt_llm_bloom = weight_only_quantize(xtrt_llm_bloom, args.quant_mode)
|
||||
|
||||
if args.model_dir is not None:
|
||||
logger.info(f'Loading HF BLOOM ... from {args.model_dir}')
|
||||
tik = time.time()
|
||||
hf_bloom = BloomForCausalLM.from_pretrained(args.model_dir,
|
||||
torch_dtype="auto")
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'HF BLOOM loaded. Total time: {t}')
|
||||
print(hf_bloom)
|
||||
load_from_hf_bloom(xtrt_llm_bloom,
|
||||
hf_bloom,
|
||||
rank,
|
||||
args.world_size,
|
||||
fp16=(args.dtype == 'float16'),
|
||||
use_parallel_embedding=args.use_parallel_embedding,
|
||||
sharding_dim=args.embedding_sharding_dim,
|
||||
share_embedding_table=share_embedding_table)
|
||||
elif args.bin_model_dir is not None:
|
||||
load_from_bin(xtrt_llm_bloom,
|
||||
args.bin_model_dir,
|
||||
rank,
|
||||
args.world_size,
|
||||
args.dtype,
|
||||
use_parallel_embedding=args.use_parallel_embedding,
|
||||
sharding_dim=args.embedding_sharding_dim,
|
||||
share_embedding_table=share_embedding_table)
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
if args.use_gpt_attention_plugin:
|
||||
network.plugin_config.set_gpt_attention_plugin(
|
||||
dtype=args.use_gpt_attention_plugin)
|
||||
if args.use_gemm_plugin:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
|
||||
if args.use_layernorm_plugin:
|
||||
network.plugin_config.set_layernorm_plugin(
|
||||
dtype=args.use_layernorm_plugin)
|
||||
if args.use_lookup_plugin:
|
||||
# Use the plugin for the embedding parallelism
|
||||
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
|
||||
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
|
||||
if args.enable_context_fmha:
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
if args.enable_context_fmha_fp32_acc:
|
||||
network.plugin_config.set_context_fmha(
|
||||
ContextFMHAType.enabled_with_fp32_acc)
|
||||
# Quantization plugins.
|
||||
if args.use_smooth_quant:
|
||||
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_layernorm_quantization_plugin(
|
||||
dtype=args.dtype)
|
||||
|
||||
network.plugin_config.set_quantize_tensor_plugin()
|
||||
network.plugin_config.set_quantize_per_token_plugin()
|
||||
elif args.use_weight_only:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype=args.dtype)
|
||||
if args.quant_mode.is_weight_only():
|
||||
builder_config.trt_builder_config.use_weight_only = args.weight_only_precision
|
||||
if args.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(args.dtype)
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(xtrt_llm_bloom.named_parameters())
|
||||
|
||||
# Forward
|
||||
inputs = xtrt_llm_bloom.prepare_inputs(args.max_batch_size,
|
||||
args.max_input_len,
|
||||
args.max_output_len, True,
|
||||
args.max_beam_width)
|
||||
xtrt_llm_bloom(*inputs)
|
||||
if args.enable_debug_output:
|
||||
# mark intermediate nodes' outputs
|
||||
for k, v in xtrt_llm_bloom.named_network_outputs():
|
||||
v = v.trt_tensor
|
||||
v.name = k
|
||||
network.trt_network.mark_output(v)
|
||||
v.dtype = kv_dtype
|
||||
if args.visualize:
|
||||
model_path = os.path.join(args.output_dir, 'test.onnx')
|
||||
to_onnx(network.trt_network, model_path)
|
||||
|
||||
# xtrt_llm.graph_rewriting.optimize(network)
|
||||
|
||||
engine = None
|
||||
|
||||
# Network -> Engine
|
||||
engine = builder.build_engine(network, builder_config)
|
||||
if rank == 0:
|
||||
config_path = os.path.join(args.output_dir, 'config.json')
|
||||
builder.save_config(builder_config, config_path)
|
||||
return engine
|
||||
|
||||
|
||||
def build(rank, args):
|
||||
torch.cuda.set_device(rank % args.gpus_per_node)
|
||||
xtrt_llm.logger.set_level(args.log_level)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
# when doing serializing build, all ranks share one engine
|
||||
builder = Builder()
|
||||
|
||||
cache = None
|
||||
for cur_rank in range(args.world_size):
|
||||
# skip other ranks if parallel_build is enabled
|
||||
if args.parallel_build and cur_rank != rank:
|
||||
continue
|
||||
# NOTE: when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
|
||||
int8_trt_flag = args.quant_mode.has_act_and_weight_quant(
|
||||
) or args.quant_mode.has_int8_kv_cache()
|
||||
builder_config = builder.create_builder_config(
|
||||
name=MODEL_NAME,
|
||||
precision=args.dtype,
|
||||
timing_cache=args.timing_cache if cache is None else cache,
|
||||
tensor_parallel=args.world_size, # TP only
|
||||
parallel_build=args.parallel_build,
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
hidden_size=args.n_embd,
|
||||
inter_size=args.mlp_hidden_size,
|
||||
vocab_size=args.vocab_size,
|
||||
max_position_embeddings=args.n_positions,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
int8=(args.quant_mode.has_act_and_weight_quant()
|
||||
or args.quant_mode.has_int8_kv_cache()),
|
||||
fusion_pattern_list=["remove_dup_mask"],
|
||||
quant_mode=args.quant_mode)
|
||||
guard = xtrt_llm.fusion_patterns.FuseonPatternGuard()
|
||||
print(guard)
|
||||
builder_config.trt_builder_config.builder_optimization_level = 1
|
||||
engine_name = get_engine_name(MODEL_NAME, args.dtype, args.world_size,
|
||||
cur_rank)
|
||||
engine = build_rank_engine(builder, builder_config, engine_name,
|
||||
cur_rank, args)
|
||||
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
|
||||
|
||||
if cur_rank == 0:
|
||||
# Use in-memory timing cache for multiple builder passes.
|
||||
if not args.parallel_build:
|
||||
cache = builder_config.trt_builder_config.get_timing_cache()
|
||||
|
||||
serialize_engine(engine, os.path.join(args.output_dir, engine_name))
|
||||
|
||||
# if rank == 0:
|
||||
# ok = builder.save_timing_cache(
|
||||
# builder_config, os.path.join(args.output_dir, "model.cache"))
|
||||
# assert ok, "Failed to save timing cache."
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
logger.set_level(args.log_level)
|
||||
tik = time.time()
|
||||
if args.parallel_build and args.world_size > 1 and \
|
||||
torch.cuda.device_count() >= args.world_size:
|
||||
logger.warning(
|
||||
f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
|
||||
)
|
||||
mp.spawn(build, nprocs=args.world_size, args=(args, ))
|
||||
else:
|
||||
args.parallel_build = False
|
||||
logger.info('Serially build TensorRT engines.')
|
||||
build(0, args)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Total time of building all {args.world_size} engines: {t}')
|
||||
283
examples/bloom/convert.py
Normal file
283
examples/bloom/convert.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Utilities for exporting a model to our custom format.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from xtrt_llm._utils import torch_to_numpy
|
||||
|
||||
|
||||
def cpu_map_location(storage, loc):
|
||||
return storage.cpu()
|
||||
|
||||
|
||||
def gpu_map_location(storage, loc):
|
||||
if loc.startswith("cuda"):
|
||||
training_gpu_idx = int(loc.split(":")[1])
|
||||
inference_gpu_idx = training_gpu_idx % torch.cuda.device_count()
|
||||
return storage.cuda(inference_gpu_idx)
|
||||
elif loc.startswith("cpu"):
|
||||
return storage.cpu()
|
||||
else:
|
||||
raise ValueError(f"Not handled {loc}")
|
||||
|
||||
|
||||
def save_val(val, dir, key, tp_num=None):
|
||||
suffix = "bin" if tp_num is None else f"{tp_num}.bin"
|
||||
val.tofile(dir / f"model.{key}.{suffix}")
|
||||
|
||||
|
||||
def save_split(split_vals, dir, key, i, split_factor):
|
||||
for j, val in enumerate(split_vals):
|
||||
save_val(val, dir, key, i * split_factor + j)
|
||||
|
||||
|
||||
def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
|
||||
"""
|
||||
This function has two purposes:
|
||||
- compute quantized weights, scaled either per-tensor or per-column
|
||||
- compute scaling factors
|
||||
|
||||
Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
|
||||
CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W.
|
||||
CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor.
|
||||
|
||||
Here is the list of what we need (T means per-tensor, C per-column):
|
||||
- scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T)
|
||||
- scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
|
||||
- scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
|
||||
- scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
|
||||
to quant range (int8) (used for CUBLAS) (T, C)
|
||||
|
||||
Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too,
|
||||
but then the model would change depending on the number of GPUs used.
|
||||
|
||||
For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it
|
||||
as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V.
|
||||
"""
|
||||
|
||||
# compute weight scaling factors for fp->int8 and int8->fp
|
||||
if is_qkv and not multi_query_mode:
|
||||
scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max(
|
||||
dim=-1, keepdims=True)[0].cpu().numpy()
|
||||
scale_w_orig_quant_c = 127. / act_range["w"].reshape(3,
|
||||
-1).cpu().numpy()
|
||||
elif is_qkv and multi_query_mode:
|
||||
raise ValueError(
|
||||
f"Multi-query w/ int8 quant has not been supported yet")
|
||||
else:
|
||||
scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy()
|
||||
scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy()
|
||||
scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
|
||||
scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c
|
||||
|
||||
# compute the rest of needed scaling factors
|
||||
scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item())
|
||||
scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item())
|
||||
scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.)
|
||||
scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t *
|
||||
scale_w_orig_quant_t)
|
||||
scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t *
|
||||
scale_w_orig_quant_c)
|
||||
if is_qkv:
|
||||
scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t,
|
||||
scale_w_orig_quant_c.shape)
|
||||
scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t,
|
||||
scale_w_orig_quant_c.shape)
|
||||
|
||||
to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8)
|
||||
return {
|
||||
"weight.int8": to_i8(weights * scale_w_orig_quant_t),
|
||||
"weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
|
||||
"scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32),
|
||||
"scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32),
|
||||
"scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32),
|
||||
"scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32),
|
||||
"scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32),
|
||||
"scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32),
|
||||
}
|
||||
|
||||
|
||||
def write_int8(vals,
|
||||
dir,
|
||||
base_key,
|
||||
split_dim,
|
||||
tp_rank,
|
||||
split_factor,
|
||||
kv_cache_only=False):
|
||||
if not kv_cache_only:
|
||||
save_split(np.split(vals["weight.int8"], split_factor, axis=split_dim),
|
||||
dir, f"{base_key}.weight.int8", tp_rank, split_factor)
|
||||
save_split(
|
||||
np.split(vals["weight.int8.col"], split_factor, axis=split_dim),
|
||||
dir, f"{base_key}.weight.int8.col", tp_rank, split_factor)
|
||||
|
||||
saved_keys_once = ["scale_y_quant_orig"]
|
||||
if not kv_cache_only:
|
||||
saved_keys_once += [
|
||||
"scale_x_orig_quant", "scale_w_quant_orig", "scale_y_accum_quant"
|
||||
]
|
||||
# per-column scaling factors are loaded per-gpu for ColumnParallel GEMMs (QKV, FC1)
|
||||
if not kv_cache_only:
|
||||
if split_dim == -1:
|
||||
save_split(
|
||||
np.split(vals["scale_w_quant_orig.col"],
|
||||
split_factor,
|
||||
axis=split_dim), dir,
|
||||
f"{base_key}.scale_w_quant_orig.col", tp_rank, split_factor)
|
||||
save_split(
|
||||
np.split(vals["scale_y_accum_quant.col"],
|
||||
split_factor,
|
||||
axis=split_dim), dir,
|
||||
f"{base_key}.scale_y_accum_quant.col", tp_rank, split_factor)
|
||||
else:
|
||||
saved_keys_once += [
|
||||
"scale_w_quant_orig.col", "scale_y_accum_quant.col"
|
||||
]
|
||||
|
||||
if tp_rank == 0:
|
||||
for save_key in saved_keys_once:
|
||||
save_val(vals[save_key], dir, f"{base_key}.{save_key}")
|
||||
|
||||
|
||||
# Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head
|
||||
# are not split as there is only one head per key/value.
|
||||
@torch.no_grad()
|
||||
def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals,
|
||||
storage_type, act_range, config):
|
||||
use_attention_nemo_shape = config.get("use_attention_nemo_shape", False)
|
||||
split_gated_activation = config.get("split_gated_activation", False)
|
||||
num_attention_heads = config.get("num_attention_heads", 0)
|
||||
tp_size = config.get("tp_size", 1)
|
||||
int8_outputs = config.get("int8_outputs", None)
|
||||
multi_query_mode = config.get("multi_query_mode", False)
|
||||
local_dim = config.get("local_dim", None)
|
||||
|
||||
save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only"
|
||||
|
||||
if not isinstance(vals, list):
|
||||
vals = [vals]
|
||||
|
||||
if config.get("transpose_weights", False) and vals[0].ndim == 2:
|
||||
vals = [val.T for val in vals]
|
||||
if "layernorm.weight" in key and config.get("apply_layernorm_1p", False):
|
||||
vals = [val + 1.0 for val in vals]
|
||||
vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals]
|
||||
|
||||
if "input_layernorm.weight" in key or "input_layernorm.bias" in key or \
|
||||
"attention.dense.bias" in key or "post_attention_layernorm.weight" in key or \
|
||||
"post_attention_layernorm.bias" in key or "mlp.dense_4h_to_h.bias" in key or \
|
||||
"final_layernorm.weight" in key or "final_layernorm.bias" in key or \
|
||||
"word_embeddings_layernorm.weight" in key or "word_embeddings_layernorm.bias" in key:
|
||||
|
||||
# shared weights, only need to convert the weights of rank 0
|
||||
if tp_rank == 0:
|
||||
save_val(vals[0], saved_dir, key)
|
||||
|
||||
elif "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
|
||||
cat_dim = 0
|
||||
val = np.concatenate(vals, axis=cat_dim)
|
||||
split_vals = np.split(val, split_factor, axis=cat_dim)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
if act_range is not None and int8_outputs == "all":
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val,
|
||||
act_range,
|
||||
multi_query_mode=multi_query_mode)
|
||||
write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank,
|
||||
split_factor)
|
||||
|
||||
elif "mlp.dense_h_to_4h.weight" in key or "mlp.dense_h_to_4h.bias" in key:
|
||||
if split_gated_activation:
|
||||
splits = [np.split(val, 2, axis=-1) for val in vals]
|
||||
vals, gates = list(zip(*splits))
|
||||
cat_dim = -1
|
||||
val = np.concatenate(vals, axis=cat_dim)
|
||||
split_vals = np.split(val, split_factor, axis=cat_dim)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
if act_range is not None and int8_outputs == "all":
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val,
|
||||
act_range,
|
||||
multi_query_mode=multi_query_mode)
|
||||
write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank,
|
||||
split_factor)
|
||||
|
||||
if split_gated_activation:
|
||||
assert not save_int8
|
||||
prefix, dot, suffix = key.rpartition(".")
|
||||
key = prefix + ".gate" + dot + suffix
|
||||
|
||||
gate = np.concatenate(gates, axis=cat_dim)
|
||||
split_vals = np.split(gate, split_factor, axis=cat_dim)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
|
||||
elif "attention.query_key_value.bias" in key:
|
||||
if local_dim is None:
|
||||
local_dim = vals[0].shape[-1] // 3
|
||||
|
||||
if multi_query_mode:
|
||||
val = vals[0]
|
||||
# out_feature = local_dim + 2 * head_size; assumes local_dim equals to hidden_dim
|
||||
b_q, b_kv = np.split(val, [local_dim], axis=-1)
|
||||
b_q_split = np.split(b_q, split_factor, axis=-1)
|
||||
split_vals = [np.concatenate((i, b_kv), axis=-1) for i in b_q_split]
|
||||
else:
|
||||
if use_attention_nemo_shape:
|
||||
head_num = num_attention_heads // tp_size
|
||||
size_per_head = local_dim // num_attention_heads
|
||||
nemo_shape = (head_num, 3, size_per_head)
|
||||
vals = [val.reshape(nemo_shape) for val in vals]
|
||||
vals = [val.transpose(1, 0, 2) for val in vals]
|
||||
|
||||
vals = [val.reshape(3, local_dim) for val in vals]
|
||||
val = np.concatenate(vals, axis=-1)
|
||||
split_vals = np.split(val, split_factor, axis=-1)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
|
||||
elif "attention.query_key_value.weight" in key:
|
||||
hidden_dim = vals[0].shape[0]
|
||||
if local_dim is None:
|
||||
local_dim = vals[0].shape[-1] // 3
|
||||
if multi_query_mode:
|
||||
val = vals[0]
|
||||
# out_feature = local_dim + 2 * head_size; assumes local_dim equals to hidden_dim
|
||||
head_size = (val.shape[-1] - local_dim) // 2
|
||||
val = val.reshape(hidden_dim, local_dim + 2 * head_size)
|
||||
w_q, w_kv = np.split(val, [local_dim], axis=-1)
|
||||
w_q_split = np.split(w_q, split_factor, axis=-1)
|
||||
split_vals = [np.concatenate((i, w_kv), axis=-1) for i in w_q_split]
|
||||
else:
|
||||
if use_attention_nemo_shape:
|
||||
head_num = num_attention_heads // tp_size
|
||||
size_per_head = hidden_dim // num_attention_heads
|
||||
vals = [
|
||||
val.reshape(hidden_dim, head_num, 3, size_per_head)
|
||||
for val in vals
|
||||
]
|
||||
vals = [val.transpose(0, 2, 1, 3) for val in vals]
|
||||
|
||||
vals = [val.reshape(hidden_dim, 3, local_dim) for val in vals]
|
||||
cat_dim = -1
|
||||
val = np.concatenate(vals, axis=cat_dim)
|
||||
split_vals = np.split(val, split_factor, axis=cat_dim)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
if save_int8:
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val,
|
||||
act_range,
|
||||
is_qkv=True,
|
||||
multi_query_mode=multi_query_mode)
|
||||
write_int8(vals_i8,
|
||||
saved_dir,
|
||||
base_key,
|
||||
cat_dim,
|
||||
tp_rank,
|
||||
split_factor,
|
||||
kv_cache_only=int8_outputs == "kv_cache_only")
|
||||
elif "attention.dense.smoother" in key or "mlp.dense_4h_to_h.smoother" in key:
|
||||
split_vals = np.split(vals[0], split_factor, axis=0)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
else:
|
||||
print(f"[WARNING] {key} not handled by converter")
|
||||
363
examples/bloom/hf_bloom_convert.py
Normal file
363
examples/bloom/hf_bloom_convert.py
Normal file
@@ -0,0 +1,363 @@
|
||||
'''
|
||||
Convert huggingface Bloom model. Use https://huggingface.co/bigscience/bloom as demo.
|
||||
'''
|
||||
import argparse
|
||||
import configparser
|
||||
import dataclasses
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as multiprocessing
|
||||
from convert import split_and_save_weight
|
||||
from smoothquant import capture_activation_range, smooth_gemm
|
||||
from tqdm import tqdm
|
||||
from transformers import BloomForCausalLM, BloomTokenizerFast
|
||||
from transformers.models.bloom.modeling_bloom import BloomBlock
|
||||
|
||||
from xtrt_llm._utils import str_dtype_to_torch, torch_to_numpy
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ProgArgs:
|
||||
out_dir: str
|
||||
in_file: str
|
||||
tensor_parallelism: int = 1
|
||||
processes: int = 4
|
||||
calibrate_kv_cache: bool = False
|
||||
smoothquant: float = None
|
||||
model: str = "bloom"
|
||||
storage_type: str = "fp32"
|
||||
dataset_cache_dir: str = None
|
||||
load_model_on_cpu: bool = False
|
||||
convert_model_on_cpu: bool = False
|
||||
|
||||
@staticmethod
|
||||
def parse(args=None) -> 'ProgArgs':
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
parser.add_argument('--out-dir',
|
||||
'-o',
|
||||
type=str,
|
||||
help='file name of output directory',
|
||||
required=True)
|
||||
parser.add_argument('--in-file',
|
||||
'-i',
|
||||
type=str,
|
||||
help='file name of input checkpoint file',
|
||||
required=True)
|
||||
parser.add_argument('--tensor-parallelism',
|
||||
'-tp',
|
||||
type=int,
|
||||
help='Requested tensor parallelism for inference',
|
||||
default=1)
|
||||
parser.add_argument(
|
||||
"--processes",
|
||||
"-p",
|
||||
type=int,
|
||||
help=
|
||||
"How many processes to spawn for conversion (default: 4). Set it to a lower value to reduce RAM usage.",
|
||||
default=4)
|
||||
parser.add_argument(
|
||||
"--calibrate-kv-cache",
|
||||
"-kv",
|
||||
action="store_true",
|
||||
help=
|
||||
"Generate scaling factors for KV cache. Used for storing KV cache in int8."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--smoothquant",
|
||||
"-sq",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
|
||||
" to Smoothquant the model, and output int8 weights."
|
||||
" A good first try is 0.5. Must be in [0, 1]")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="bloom",
|
||||
type=str,
|
||||
help="Specify Bloom variants to convert checkpoints correctly",
|
||||
choices=["bloom"])
|
||||
parser.add_argument("--storage-type",
|
||||
"-t",
|
||||
type=str,
|
||||
default="float32",
|
||||
choices=["float32", "float16", "bfloat16"])
|
||||
parser.add_argument("--dataset-cache-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="cache dir to load the hugging face dataset")
|
||||
parser.add_argument("--load-model-on-cpu", action="store_true")
|
||||
parser.add_argument("--convert-model-on-cpu", action="store_true")
|
||||
return ProgArgs(**vars(parser.parse_args(args)))
|
||||
|
||||
|
||||
def reorder_torch_qkv_weight_or_bias(v, model, is_bias=False):
|
||||
""" Reorder the qkv weight.
|
||||
|
||||
Note that the shape of the fused QKV weights in HF is different from the
|
||||
shape that XTRT-LLM requires.
|
||||
HF: (num_heads x 3 x head_dim, hidden_size)
|
||||
XTRT-LLM: (3 x num_heads x head_dim, hidden_size)
|
||||
This is unlike to the other models in HF e.g. GPT where they have the
|
||||
same shape with XTRT-LLM, i.e., (3 x num_heads x head_dim, hidden_size). We reshape the qkv
|
||||
weight: (3 x num_heads x head_dim, hidden).
|
||||
bias : (3 x num_heads x head_dim).
|
||||
"""
|
||||
|
||||
n_head = model.transformer.num_heads
|
||||
hidden_size = model.transformer.embed_dim
|
||||
head_dim = hidden_size // n_head
|
||||
|
||||
# (3 x hidden, ...) view as (num_heads, 3, head_dim, ...)
|
||||
v = v.reshape(n_head, 3, head_dim, -1)
|
||||
# permute to (3, num_heads, head_dim, ...)
|
||||
v = v.permute((1, 0, 2, 3))
|
||||
# final shape: weight=(3 x hidden, hidden) or bias=(3 x hidden)
|
||||
if is_bias:
|
||||
return v.reshape(3 * hidden_size)
|
||||
return v.reshape(3 * hidden_size, hidden_size)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_bloom_model(model, scales, alpha, bloom_qkv_param, bloom_smoother):
|
||||
# Smooth the activation and weights with smoother = $\diag{s}$
|
||||
for name, module in model.named_modules():
|
||||
if not isinstance(module, BloomBlock):
|
||||
continue
|
||||
|
||||
# reorder qkv weight/bias and scales
|
||||
param = module.self_attention.query_key_value.weight
|
||||
param = reorder_torch_qkv_weight_or_bias(param, model, is_bias=False)
|
||||
|
||||
layer_name = name + ".self_attention.query_key_value"
|
||||
act_range_qkv = scales.get(layer_name)
|
||||
# (n_head x 3 x head_dim) -> (3 x n_head x head_dim)
|
||||
act_range_qkv['w'] = reorder_torch_qkv_weight_or_bias(
|
||||
act_range_qkv['w'], model, is_bias=True)
|
||||
act_range_qkv['y'] = reorder_torch_qkv_weight_or_bias(
|
||||
act_range_qkv['y'], model, is_bias=True)
|
||||
scales[layer_name] = act_range_qkv
|
||||
|
||||
# qkv_proj
|
||||
smoother = smooth_gemm(param, scales[layer_name]["x"],
|
||||
module.input_layernorm.weight,
|
||||
module.input_layernorm.bias, alpha)
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
|
||||
scales[layer_name]["w"] = param.abs().max(dim=1)[0]
|
||||
bloom_qkv_param[layer_name] = param
|
||||
|
||||
# dense
|
||||
# enabled for better accuracy with perf overhead of quantiztion
|
||||
layer_name = name + ".self_attention.dense"
|
||||
smoother = smooth_gemm(module.self_attention.dense.weight,
|
||||
scales[layer_name]["x"], None, None, alpha)
|
||||
bloom_smoother[layer_name] = smoother
|
||||
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
|
||||
scales[layer_name]["w"] = module.self_attention.dense.weight.abs().max(
|
||||
dim=1)[0]
|
||||
|
||||
# fc1
|
||||
layer_name = name + ".mlp.dense_h_to_4h"
|
||||
smoother = smooth_gemm(module.mlp.dense_h_to_4h.weight,
|
||||
scales[layer_name]["x"],
|
||||
module.post_attention_layernorm.weight,
|
||||
module.post_attention_layernorm.bias, alpha)
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
|
||||
scales[layer_name]["w"] = module.mlp.dense_h_to_4h.weight.abs().max(
|
||||
dim=1)[0]
|
||||
|
||||
# fc2
|
||||
# enabled for better accuracy with perf overhead of quantiztion
|
||||
layer_name = name + ".mlp.dense_4h_to_h"
|
||||
smoother = smooth_gemm(module.mlp.dense_4h_to_h.weight,
|
||||
scales[layer_name]["x"], None, None, alpha)
|
||||
bloom_smoother[layer_name] = smoother
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
|
||||
scales[layer_name]["w"] = module.mlp.dense_4h_to_h.weight.abs().max(
|
||||
dim=1)[0]
|
||||
|
||||
|
||||
# Bloom uses nn.Linear for these following ops whose weight matrix is transposed compared to transformer.Conv1D
|
||||
def transpose_weights(hf_name, param):
|
||||
weight_to_transpose = [
|
||||
"self_attention.query_key_value", "self_attention.dense",
|
||||
"mlp.dense_h_to_4h", "mlp.dense_4h_to_h"
|
||||
]
|
||||
if any([k in hf_name for k in weight_to_transpose]):
|
||||
if len(param.shape) == 2:
|
||||
param = param.transpose(0, 1)
|
||||
return param
|
||||
|
||||
|
||||
def bloom_to_trt_llm_name(orig_name):
|
||||
global_weights = {
|
||||
"transformer.word_embeddings.weight": "model.wpe",
|
||||
"transformer.word_embeddings_layernorm.bias":
|
||||
"model.word_embeddings_layernorm.bias",
|
||||
"transformer.word_embeddings_layernorm.weight":
|
||||
"model.word_embeddings_layernorm.weight",
|
||||
"transformer.ln_f.bias": "model.final_layernorm.bias",
|
||||
"transformer.ln_f.weight": "model.final_layernorm.weight",
|
||||
"lm_head.weight": "model.lm_head.weight"
|
||||
}
|
||||
|
||||
if orig_name in global_weights:
|
||||
return global_weights[orig_name]
|
||||
|
||||
_, _, layer_id, *weight_name = orig_name.split(".")
|
||||
layer_id = int(layer_id)
|
||||
weight_name = "transformer." + ".".join(weight_name)
|
||||
|
||||
per_layer_weights = {
|
||||
"transformer.input_layernorm.bias": "input_layernorm.bias",
|
||||
"transformer.input_layernorm.weight": "input_layernorm.weight",
|
||||
"transformer.self_attention.query_key_value.bias":
|
||||
"attention.query_key_value.bias",
|
||||
"transformer.self_attention.query_key_value.weight":
|
||||
"attention.query_key_value.weight",
|
||||
"transformer.self_attention.dense.bias": "attention.dense.bias",
|
||||
"transformer.self_attention.dense.weight": "attention.dense.weight",
|
||||
"transformer.post_attention_layernorm.bias":
|
||||
"post_attention_layernorm.bias",
|
||||
"transformer.post_attention_layernorm.weight":
|
||||
"post_attention_layernorm.weight",
|
||||
"transformer.mlp.dense_h_to_4h.bias": "mlp.dense_h_to_4h.bias",
|
||||
"transformer.mlp.dense_h_to_4h.weight": "mlp.dense_h_to_4h.weight",
|
||||
"transformer.mlp.dense_4h_to_h.bias": "mlp.dense_4h_to_h.bias",
|
||||
"transformer.mlp.dense_4h_to_h.weight": "mlp.dense_4h_to_h.weight",
|
||||
}
|
||||
return f"layers.{layer_id}.{per_layer_weights[weight_name]}"
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def hf_bloom_converter(args: ProgArgs):
|
||||
infer_tp = args.tensor_parallelism
|
||||
multi_query_mode = True if args.model in ["santacoder", "starcoder"
|
||||
] else False
|
||||
saved_dir = Path(args.out_dir) / f"{infer_tp}-XPU"
|
||||
saved_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# load position_embedding from rank 0
|
||||
model = BloomForCausalLM.from_pretrained(args.in_file,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
trust_remote_code=True)
|
||||
if args.load_model_on_cpu:
|
||||
model = model.float()
|
||||
model = model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
act_range = {}
|
||||
bloom_qkv_param = {}
|
||||
# smoother for inputs of self_attention.dense and mlp.dense_4h_to_h
|
||||
bloom_smoother = {}
|
||||
|
||||
if args.smoothquant is not None or args.calibrate_kv_cache:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
|
||||
"TOKENIZERS_PARALLELISM", "false")
|
||||
from datasets import load_dataset
|
||||
dataset = load_dataset("lambada",
|
||||
split="validation",
|
||||
cache_dir=args.dataset_cache_dir)
|
||||
act_range = capture_activation_range(
|
||||
model, BloomTokenizerFast.from_pretrained(args.in_file), dataset)
|
||||
if args.smoothquant is not None:
|
||||
smooth_bloom_model(model, act_range, args.smoothquant,
|
||||
bloom_qkv_param, bloom_smoother)
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config["bloom"] = {}
|
||||
for key in vars(args):
|
||||
config["bloom"][key] = f"{vars(args)[key]}"
|
||||
for k, v in vars(model.config).items():
|
||||
config["bloom"][k] = f"{v}"
|
||||
config["bloom"]["storage_dtype"] = args.storage_type
|
||||
config["bloom"]["multi_query_mode"] = str(multi_query_mode)
|
||||
with open(saved_dir / "config.ini", 'w') as configfile:
|
||||
config.write(configfile)
|
||||
|
||||
storage_type = str_dtype_to_torch(args.storage_type)
|
||||
|
||||
global_trt_llm_weights = [
|
||||
"model.wpe", "model.word_embeddings_layernorm.bias",
|
||||
"model.word_embeddings_layernorm.weight", "model.final_layernorm.bias",
|
||||
"model.final_layernorm.weight", "model.lm_head.weight"
|
||||
]
|
||||
|
||||
int8_outputs = None
|
||||
if args.calibrate_kv_cache:
|
||||
int8_outputs = "kv_cache_only"
|
||||
if args.smoothquant is not None:
|
||||
int8_outputs = "all"
|
||||
|
||||
starmap_args = []
|
||||
for name, param in model.named_parameters():
|
||||
if "weight" not in name and "bias" not in name:
|
||||
continue
|
||||
trt_llm_name = bloom_to_trt_llm_name(name)
|
||||
|
||||
if args.convert_model_on_cpu:
|
||||
param = param.cpu()
|
||||
if name.replace(".weight", "") in bloom_smoother.keys():
|
||||
smoother = bloom_smoother[name.replace(".weight", "")]
|
||||
starmap_args.append(
|
||||
(0, saved_dir, infer_tp,
|
||||
f"{trt_llm_name}.smoother".replace(".weight", ""),
|
||||
smoother.to(torch.float32), torch.float32, None, {
|
||||
"int8_outputs": int8_outputs,
|
||||
"multi_query_mode": multi_query_mode,
|
||||
"local_dim": None,
|
||||
}))
|
||||
|
||||
# reorder qkv weight and bias
|
||||
if "attention.query_key_value.weight" in trt_llm_name:
|
||||
if args.smoothquant is not None:
|
||||
param = bloom_qkv_param.get(name.replace(".weight", ""))
|
||||
else:
|
||||
param = reorder_torch_qkv_weight_or_bias(param,
|
||||
model,
|
||||
is_bias=False)
|
||||
if "attention.query_key_value.bias" in trt_llm_name:
|
||||
param = reorder_torch_qkv_weight_or_bias(param, model, is_bias=True)
|
||||
|
||||
param = transpose_weights(name, param)
|
||||
|
||||
if trt_llm_name in global_trt_llm_weights:
|
||||
torch_to_numpy(param.to(storage_type).cpu()).tofile(
|
||||
saved_dir / f"{trt_llm_name}.bin")
|
||||
else:
|
||||
# Needed by QKV projection weight split. With multi_query_mode one does not simply take
|
||||
# out_dim and divide it by 3 to get local_dim becuase out_dim = local_dim + 2 * head_size
|
||||
local_dim = model.transformer.h[
|
||||
0].attn.embed_dim if multi_query_mode else None
|
||||
starmap_args.append(
|
||||
(0, saved_dir, infer_tp, trt_llm_name, param.to(storage_type),
|
||||
storage_type, act_range.get(name.replace(".weight", "")), {
|
||||
"int8_outputs": int8_outputs,
|
||||
"multi_query_mode": multi_query_mode,
|
||||
"local_dim": local_dim
|
||||
}))
|
||||
|
||||
starmap_args = tqdm(starmap_args, desc="saving weights")
|
||||
if args.processes > 1:
|
||||
with multiprocessing.Pool(args.processes) as pool:
|
||||
pool.starmap(split_and_save_weight, starmap_args)
|
||||
else:
|
||||
# simpler for debug situations
|
||||
for starmap_arg in starmap_args:
|
||||
split_and_save_weight(*starmap_arg)
|
||||
|
||||
|
||||
def run_conversion(args: ProgArgs):
|
||||
print("\n=============== Arguments ===============")
|
||||
for key, value in vars(args).items():
|
||||
print(f"{key}: {value}")
|
||||
print("========================================")
|
||||
hf_bloom_converter(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
run_conversion(ProgArgs.parse())
|
||||
3
examples/bloom/requirements.txt
Normal file
3
examples/bloom/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
datasets~=2.3.2
|
||||
rouge_score~=0.1.2
|
||||
sentencepiece~=0.1.99
|
||||
130
examples/bloom/run.py
Normal file
130
examples/bloom/run.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
from transformers import BloomTokenizerFast
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm.runtime import ModelConfig, SamplingConfig
|
||||
import numpy as np
|
||||
|
||||
from build import get_engine_name # isort:skip
|
||||
|
||||
EOS_TOKEN = 2
|
||||
PAD_TOKEN = 3
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--max_output_len', type=int, required=True)
|
||||
parser.add_argument('--log_level', type=str, default='error')
|
||||
parser.add_argument('--engine_dir', type=str, default='bloom_outputs')
|
||||
parser.add_argument('--tokenizer_dir',
|
||||
type=str,
|
||||
default=".",
|
||||
help="Directory containing the tokenizer.model.")
|
||||
parser.add_argument('--input_text',
|
||||
type=str,
|
||||
default='Born in north-east France, Soyer trained as a')
|
||||
parser.add_argument(
|
||||
'--performance_test_scale',
|
||||
type=str,
|
||||
help=
|
||||
"Scale for performance test. e.g., 8x1024x64 (batch_size, input_text_length, max_output_length)",
|
||||
default="")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
xtrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
config_path = os.path.join(args.engine_dir, 'config.json')
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
|
||||
dtype = config['builder_config']['precision']
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
assert world_size == xtrt_llm.mpi_world_size(), \
|
||||
f'Engine world size ({world_size}) != Runtime world size ({xtrt_llm.mpi_world_size()})'
|
||||
num_heads = config['builder_config']['num_heads'] // world_size
|
||||
hidden_size = config['builder_config']['hidden_size'] // world_size
|
||||
vocab_size = config['builder_config']['vocab_size']
|
||||
num_layers = config['builder_config']['num_layers']
|
||||
|
||||
runtime_rank = xtrt_llm.mpi_rank()
|
||||
if world_size > 1:
|
||||
os.environ["XCCL_GROUP_ID"] = str(runtime_rank // world_size)
|
||||
os.environ["XCCL_NRANKS"] = str(world_size)
|
||||
os.environ["XCCL_CUR_RANK"] = str(runtime_rank % world_size)
|
||||
os.environ["XCCL_DEVICE_ID"] = str(runtime_rank)
|
||||
os.environ["MP_RUN"] = str(1)
|
||||
runtime_mapping = xtrt_llm.Mapping(world_size,
|
||||
runtime_rank,
|
||||
tp_size=world_size)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
|
||||
engine_name = get_engine_name('bloom', dtype, world_size, runtime_rank)
|
||||
serialize_path = os.path.join(args.engine_dir, engine_name)
|
||||
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.tokenizer_dir)
|
||||
input_ids = torch.tensor(tokenizer.encode(args.input_text),
|
||||
dtype=torch.int32).cuda().unsqueeze(0)
|
||||
|
||||
model_config = ModelConfig(num_heads=num_heads,
|
||||
num_kv_heads=num_heads,
|
||||
hidden_size=hidden_size,
|
||||
vocab_size=vocab_size,
|
||||
num_layers=num_layers,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
dtype=dtype)
|
||||
sampling_config = SamplingConfig(end_id=EOS_TOKEN, pad_id=PAD_TOKEN)
|
||||
input_lengths = torch.tensor(
|
||||
[input_ids.size(1) for _ in range(input_ids.size(0))]).int().cuda()
|
||||
|
||||
# with open(serialize_path, 'rb') as f:
|
||||
# engine_buffer = f.read()
|
||||
decoder = xtrt_llm.runtime.GenerationSession(model_config,
|
||||
serialize_path,
|
||||
runtime_mapping)
|
||||
if args.performance_test_scale != "":
|
||||
performance_test_scale_list = args.performance_test_scale.split("E")
|
||||
for scale in performance_test_scale_list:
|
||||
xtrt_llm.logger.info(f"Running performance test with scale {scale}")
|
||||
bs, seqlen, _max_output_len = [int(x) for x in scale.split("x")]
|
||||
_input_ids = torch.from_numpy(
|
||||
np.zeros((bs, seqlen)).astype("int32")).cuda()
|
||||
_input_lengths = torch.from_numpy(
|
||||
np.full((bs, ), seqlen).astype("int32")).cuda()
|
||||
_max_input_length = torch.max(_input_lengths).item()
|
||||
decoder.setup(_input_lengths.size(0), _max_input_length,
|
||||
_max_output_len)
|
||||
_output_gen_ids = decoder.decode(_input_ids,
|
||||
_input_lengths,
|
||||
sampling_config)
|
||||
decoder.setup(input_ids.size(0),
|
||||
max_context_length=input_ids.size(1),
|
||||
max_new_tokens=args.max_output_len)
|
||||
output_ids = decoder.decode(input_ids, input_lengths, sampling_config)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
output_ids = output_ids.tolist()[0][0][input_ids.size(1):]
|
||||
output_text = tokenizer.decode(output_ids)
|
||||
print(f'Input: \"{args.input_text}\"')
|
||||
print(f'Output Ids: \"{output_ids}\"')
|
||||
print(f'Output: \"{output_text}\"')
|
||||
141
examples/bloom/smoothquant.py
Normal file
141
examples/bloom/smoothquant.py
Normal file
@@ -0,0 +1,141 @@
|
||||
'''
|
||||
Utilities for SmoothQuant models
|
||||
'''
|
||||
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_smoothing(scales,
|
||||
gemm_weights,
|
||||
layernorm_weights=None,
|
||||
layernorm_bias=None,
|
||||
dtype=torch.float32,
|
||||
layernorm_1p=False):
|
||||
if not isinstance(gemm_weights, list):
|
||||
gemm_weights = [gemm_weights]
|
||||
|
||||
if layernorm_weights is not None:
|
||||
assert layernorm_weights.numel() == scales.numel()
|
||||
layernorm_weights.div_(scales).to(dtype)
|
||||
if layernorm_bias is not None:
|
||||
assert layernorm_bias.numel() == scales.numel()
|
||||
layernorm_bias.div_(scales).to(dtype)
|
||||
if layernorm_1p:
|
||||
layernorm_weights += (1 / scales) - 1
|
||||
|
||||
for gemm in gemm_weights:
|
||||
gemm.mul_(scales.view(1, -1)).to(dtype)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_gemm(gemm_weights,
|
||||
act_scales,
|
||||
layernorm_weights=None,
|
||||
layernorm_bias=None,
|
||||
alpha=0.5,
|
||||
weight_scales=None):
|
||||
if not isinstance(gemm_weights, list):
|
||||
gemm_weights = [gemm_weights]
|
||||
orig_dtype = gemm_weights[0].dtype
|
||||
|
||||
for gemm in gemm_weights:
|
||||
# gemm_weights are expected to be transposed
|
||||
assert gemm.shape[1] == act_scales.numel()
|
||||
|
||||
if weight_scales is None:
|
||||
weight_scales = torch.cat(
|
||||
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
|
||||
dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0]
|
||||
weight_scales.to(float).clamp(min=1e-5)
|
||||
scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5)
|
||||
|
||||
apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias,
|
||||
orig_dtype)
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
|
||||
if not isinstance(fcs, list):
|
||||
fcs = [fcs]
|
||||
for fc in fcs:
|
||||
assert isinstance(fc, nn.Linear)
|
||||
assert ln.weight.numel() == fc.in_features == act_scales.numel()
|
||||
|
||||
device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
|
||||
act_scales = act_scales.to(device=device, dtype=dtype)
|
||||
weight_scales = torch.cat(
|
||||
[fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
|
||||
|
||||
scales = (act_scales.pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
|
||||
|
||||
if ln is not None:
|
||||
ln.weight.div_(scales)
|
||||
ln.bias.div_(scales)
|
||||
|
||||
for fc in fcs:
|
||||
fc.weight.mul_(scales.view(1, -1))
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def capture_activation_range(model,
|
||||
tokenizer,
|
||||
dataset,
|
||||
num_samples=512,
|
||||
seq_len=512):
|
||||
model.eval()
|
||||
device = next(model.parameters()).device
|
||||
act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
|
||||
|
||||
def stat_tensor(name, tensor, act_scales, key):
|
||||
hidden_dim = tensor.shape[-1]
|
||||
tensor = tensor.view(-1, hidden_dim).abs().detach()
|
||||
comming_max = torch.max(tensor, dim=0)[0].float()
|
||||
|
||||
if act_scales[name][key] is None:
|
||||
act_scales[name][key] = comming_max
|
||||
else:
|
||||
act_scales[name][key] = torch.max(act_scales[name][key],
|
||||
comming_max)
|
||||
|
||||
def stat_input_hook(m, x, y, name):
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
stat_tensor(name, x, act_scales, "x")
|
||||
stat_tensor(name, y, act_scales, "y")
|
||||
|
||||
if act_scales[name]["w"] is None:
|
||||
act_scales[name]["w"] = m.weight.abs().clip(1e-8,
|
||||
None).max(dim=1)[0]
|
||||
|
||||
hooks = []
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
|
||||
hooks.append(
|
||||
m.register_forward_hook(
|
||||
functools.partial(stat_input_hook, name=name)))
|
||||
|
||||
for i in tqdm(range(num_samples), desc="calibrating model"):
|
||||
input_ids = tokenizer(dataset[i]["text"],
|
||||
return_tensors="pt",
|
||||
max_length=seq_len,
|
||||
truncation=True).input_ids.to(device)
|
||||
model(input_ids)
|
||||
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
|
||||
return act_scales
|
||||
372
examples/bloom/summarize.py
Normal file
372
examples/bloom/summarize.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# TODO Just a copy paste, needs work
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import AutoModelForCausalLM, BloomTokenizerFast
|
||||
|
||||
import xtrt_llm as tensorrt_llm
|
||||
import xtrt_llm.profiler as profiler
|
||||
from xtrt_llm.logger import logger
|
||||
|
||||
from build import get_engine_name # isort:skip
|
||||
|
||||
|
||||
def TRTBloom(args, config):
|
||||
dtype = config['builder_config']['precision']
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
assert world_size == tensorrt_llm.mpi_world_size(), \
|
||||
f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
|
||||
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
num_heads = config['builder_config']['num_heads'] // world_size
|
||||
hidden_size = config['builder_config']['hidden_size'] // world_size
|
||||
vocab_size = config['builder_config']['vocab_size']
|
||||
num_layers = config['builder_config']['num_layers']
|
||||
use_gpt_attention_plugin = bool(
|
||||
config['plugin_config']['gpt_attention_plugin'])
|
||||
|
||||
model_config = tensorrt_llm.runtime.ModelConfig(
|
||||
vocab_size=vocab_size,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_heads,
|
||||
hidden_size=hidden_size,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
dtype=dtype)
|
||||
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
runtime_mapping = tensorrt_llm.Mapping(world_size,
|
||||
runtime_rank,
|
||||
tp_size=world_size)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
|
||||
engine_name = get_engine_name('bloom', dtype, world_size, runtime_rank)
|
||||
serialize_path = os.path.join(args.engine_dir, engine_name)
|
||||
|
||||
tensorrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
profiler.start('load tensorrt_llm engine')
|
||||
'''
|
||||
with open(serialize_path, 'rb') as f:
|
||||
engine_buffer = f.read()
|
||||
'''
|
||||
decoder = tensorrt_llm.runtime.GenerationSession(model_config,
|
||||
serialize_path,
|
||||
runtime_mapping)
|
||||
profiler.stop('load tensorrt_llm engine')
|
||||
tensorrt_llm.logger.info(
|
||||
f'Load engine takes: {profiler.elapsed_time_in_sec("load tensorrt_llm engine")} sec'
|
||||
)
|
||||
return decoder
|
||||
|
||||
|
||||
def main(args):
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
test_hf = args.test_hf and runtime_rank == 0 # only run hf on rank 0
|
||||
test_trt_llm = args.test_trt_llm
|
||||
hf_model_location = args.hf_model_location
|
||||
profiler.start('load tokenizer')
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(hf_model_location,
|
||||
padding_side='left')
|
||||
profiler.stop('load tokenizer')
|
||||
tensorrt_llm.logger.info(
|
||||
f'Load tokenizer takes: {profiler.elapsed_time_in_sec("load tokenizer")} sec'
|
||||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset_cnn = load_dataset("ccdv/cnn_dailymail",
|
||||
'3.0.0',
|
||||
cache_dir=args.dataset_path)
|
||||
|
||||
max_batch_size = args.batch_size
|
||||
|
||||
# runtime parameters
|
||||
# repetition_penalty = 1
|
||||
top_k = args.top_k
|
||||
output_len = 100
|
||||
test_token_num = 923
|
||||
# top_p = 0.0
|
||||
# random_seed = 5
|
||||
temperature = 1
|
||||
num_beams = args.num_beams
|
||||
|
||||
pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
|
||||
end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0]
|
||||
|
||||
if test_trt_llm:
|
||||
config_path = os.path.join(args.engine_dir, 'config.json')
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
tensorrt_llm_bloom = TRTBloom(args, config)
|
||||
|
||||
if test_hf:
|
||||
profiler.start('load HF model')
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_model_location)
|
||||
profiler.stop('load HF model')
|
||||
tensorrt_llm.logger.info(
|
||||
f'Load HF model takes: {profiler.elapsed_time_in_sec("load HF model")} sec'
|
||||
)
|
||||
if args.data_type == 'fp16':
|
||||
model.half()
|
||||
model.cuda()
|
||||
|
||||
def summarize_tensorrt_llm(datapoint):
|
||||
batch_size = len(datapoint['article'])
|
||||
|
||||
line = copy.copy(datapoint['article'])
|
||||
line_encoded = []
|
||||
input_lengths = []
|
||||
for i in range(batch_size):
|
||||
line[i] = line[i] + ' TL;DR: '
|
||||
|
||||
line[i] = line[i].strip()
|
||||
line[i] = line[i].replace(" n't", "n't")
|
||||
|
||||
input_id = tokenizer.encode(line[i],
|
||||
return_tensors='pt').type(torch.int32)
|
||||
input_id = input_id[:, -test_token_num:]
|
||||
|
||||
line_encoded.append(input_id)
|
||||
input_lengths.append(input_id.shape[-1])
|
||||
|
||||
# do padding, should move outside the profiling to prevent the overhead
|
||||
max_length = max(input_lengths)
|
||||
for i in range(batch_size):
|
||||
pad_size = max_length - input_lengths[i]
|
||||
|
||||
pad = torch.ones([1, pad_size]).type(torch.int32) * pad_id
|
||||
line_encoded[i] = torch.cat(
|
||||
[torch.tensor(line_encoded[i], dtype=torch.int32), pad],
|
||||
axis=-1)
|
||||
|
||||
line_encoded = torch.cat(line_encoded, axis=0).cuda()
|
||||
input_lengths = torch.tensor(input_lengths, dtype=torch.int32).cuda()
|
||||
|
||||
sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=end_id, pad_id=pad_id, top_k=top_k, num_beams=num_beams)
|
||||
|
||||
with torch.no_grad():
|
||||
tensorrt_llm_bloom.setup(line_encoded.size(0),
|
||||
max_context_length=line_encoded.size(1),
|
||||
max_new_tokens=output_len,
|
||||
beam_width=num_beams)
|
||||
|
||||
output_ids = tensorrt_llm_bloom.decode(
|
||||
line_encoded,
|
||||
input_lengths,
|
||||
sampling_config,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Extract a list of tensors of shape beam_width x output_ids.
|
||||
if tensorrt_llm_bloom.mapping.is_first_pp_rank():
|
||||
output_beams_list = [
|
||||
tokenizer.batch_decode(output_ids[batch_idx, :,
|
||||
input_lengths[batch_idx]:],
|
||||
skip_special_tokens=True)
|
||||
for batch_idx in range(batch_size)
|
||||
]
|
||||
return output_beams_list, output_ids[:, :, max_length:].tolist()
|
||||
return [], []
|
||||
|
||||
def summarize_hf(datapoint):
|
||||
batch_size = len(datapoint['article'])
|
||||
if batch_size > 1:
|
||||
logger.warning(
|
||||
f"HF does not support batch_size > 1 to verify correctness due to padding. Current batch size is {batch_size}"
|
||||
)
|
||||
|
||||
line = copy.copy(datapoint['article'])
|
||||
for i in range(batch_size):
|
||||
line[i] = line[i] + ' TL;DR: '
|
||||
|
||||
line[i] = line[i].strip()
|
||||
line[i] = line[i].replace(" n't", "n't")
|
||||
|
||||
line_encoded = tokenizer(line,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True)["input_ids"].type(torch.int64)
|
||||
|
||||
line_encoded = line_encoded[:, -test_token_num:]
|
||||
line_encoded = line_encoded.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(line_encoded,
|
||||
max_length=len(line_encoded[0]) +
|
||||
output_len,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
num_beams=num_beams,
|
||||
num_return_sequences=num_beams,
|
||||
early_stopping=True)
|
||||
|
||||
tokens_list = output[:, len(line_encoded[0]):].tolist()
|
||||
output = output.reshape([batch_size, num_beams, -1])
|
||||
output_lines_list = [
|
||||
tokenizer.batch_decode(output[:, i, len(line_encoded[0]):],
|
||||
skip_special_tokens=True)
|
||||
for i in range(num_beams)
|
||||
]
|
||||
|
||||
return output_lines_list, tokens_list
|
||||
|
||||
if test_trt_llm:
|
||||
datapoint = dataset_cnn['test'][0:1]
|
||||
summary, _ = summarize_tensorrt_llm(datapoint)
|
||||
if runtime_rank == 0:
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
logger.info("XTRT-LLM Generated : ")
|
||||
logger.info(f" Article : {datapoint['article']}")
|
||||
logger.info(f"\n Highlights : {datapoint['highlights']}")
|
||||
logger.info(f"\n Summary : {summary}")
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
|
||||
if test_hf:
|
||||
datapoint = dataset_cnn['test'][0:1]
|
||||
summary, _ = summarize_hf(datapoint)
|
||||
logger.info("---------------------------------------------------------")
|
||||
logger.info("HF Generated : ")
|
||||
logger.info(f" Article : {datapoint['article']}")
|
||||
logger.info(f"\n Highlights : {datapoint['highlights']}")
|
||||
logger.info(f"\n Summary : {summary}")
|
||||
logger.info("---------------------------------------------------------")
|
||||
|
||||
metric_tensorrt_llm = [load_metric("rouge") for _ in range(num_beams)]
|
||||
metric_hf = [load_metric("rouge") for _ in range(num_beams)]
|
||||
for i in range(num_beams):
|
||||
metric_tensorrt_llm[i].seed = 0
|
||||
metric_hf[i].seed = 0
|
||||
|
||||
ite_count = 0
|
||||
data_point_idx = 0
|
||||
while (data_point_idx < len(dataset_cnn['test'])) and (ite_count <
|
||||
args.max_ite):
|
||||
if runtime_rank == 0:
|
||||
logger.debug(
|
||||
f"run data_point {data_point_idx} ~ {data_point_idx + max_batch_size}"
|
||||
)
|
||||
datapoint = dataset_cnn['test'][data_point_idx:(data_point_idx +
|
||||
max_batch_size)]
|
||||
|
||||
if test_trt_llm:
|
||||
profiler.start('tensorrt_llm')
|
||||
summary_tensorrt_llm, tokens_tensorrt_llm = summarize_tensorrt_llm(
|
||||
datapoint)
|
||||
profiler.stop('tensorrt_llm')
|
||||
|
||||
if test_hf:
|
||||
profiler.start('hf')
|
||||
summary_hf, tokens_hf = summarize_hf(datapoint)
|
||||
profiler.stop('hf')
|
||||
|
||||
if runtime_rank == 0:
|
||||
if test_trt_llm:
|
||||
for batch_idx in range(len(summary_tensorrt_llm)):
|
||||
for beam_idx in range(num_beams):
|
||||
metric_tensorrt_llm[beam_idx].add_batch(
|
||||
predictions=[
|
||||
summary_tensorrt_llm[batch_idx][beam_idx]
|
||||
],
|
||||
references=[datapoint['highlights'][batch_idx]])
|
||||
if test_hf:
|
||||
for beam_idx in range(num_beams):
|
||||
for batch_idx in range(len(summary_hf[beam_idx])):
|
||||
metric_hf[beam_idx].add_batch(
|
||||
predictions=[summary_hf[beam_idx][batch_idx]],
|
||||
references=[datapoint['highlights'][batch_idx]])
|
||||
|
||||
logger.debug('-' * 100)
|
||||
logger.debug(f"Article : {datapoint['article']}")
|
||||
if test_trt_llm:
|
||||
logger.debug(f'XTRT-LLM Summary: {summary_tensorrt_llm}')
|
||||
if test_hf:
|
||||
logger.debug(f'HF Summary: {summary_hf}')
|
||||
logger.debug(f"highlights : {datapoint['highlights']}")
|
||||
|
||||
data_point_idx += max_batch_size
|
||||
ite_count += 1
|
||||
|
||||
if runtime_rank == 0:
|
||||
if test_trt_llm:
|
||||
np.random.seed(0) # rouge score use sampling to compute the score
|
||||
logger.info(
|
||||
f'XTRT-LLM (total latency: {profiler.elapsed_time_in_sec("tensorrt_llm")} sec)'
|
||||
)
|
||||
for beam_idx in range(num_beams):
|
||||
logger.info(f"XTRT-LLM beam {beam_idx} result")
|
||||
computed_metrics_tensorrt_llm = metric_tensorrt_llm[
|
||||
beam_idx].compute()
|
||||
for key in computed_metrics_tensorrt_llm.keys():
|
||||
logger.info(
|
||||
f' {key} : {computed_metrics_tensorrt_llm[key].mid[2]*100}'
|
||||
)
|
||||
|
||||
if args.check_accuracy and beam_idx == 0:
|
||||
assert computed_metrics_tensorrt_llm['rouge1'].mid[
|
||||
2] * 100 > args.tensorrt_llm_rouge1_threshold
|
||||
if test_hf:
|
||||
np.random.seed(0) # rouge score use sampling to compute the score
|
||||
logger.info(
|
||||
f'Hugging Face (total latency: {profiler.elapsed_time_in_sec("hf")} sec)'
|
||||
)
|
||||
for beam_idx in range(num_beams):
|
||||
logger.info(f"HF beam {beam_idx} result")
|
||||
computed_metrics_hf = metric_hf[beam_idx].compute()
|
||||
for key in computed_metrics_hf.keys():
|
||||
logger.info(
|
||||
f' {key} : {computed_metrics_hf[key].mid[2]*100}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf_model_location', type=str, default='./bloom/560M')
|
||||
parser.add_argument('--test_hf', action='store_true')
|
||||
parser.add_argument('--test_trt_llm', action='store_true')
|
||||
parser.add_argument('--data_type',
|
||||
type=str,
|
||||
choices=['fp32', 'fp16'],
|
||||
default='fp16')
|
||||
parser.add_argument('--dataset_path', type=str, default='')
|
||||
parser.add_argument('--log_level', type=str, default='info')
|
||||
parser.add_argument('--engine_dir', type=str, default='bloom_outputs')
|
||||
parser.add_argument('--batch_size', type=int, default=1)
|
||||
parser.add_argument('--max_ite', type=int, default=20)
|
||||
parser.add_argument('--check_accuracy', action='store_true')
|
||||
parser.add_argument('--tensorrt_llm_rouge1_threshold',
|
||||
type=float,
|
||||
default=15.0)
|
||||
parser.add_argument('--num_beams', type=int, default=1)
|
||||
parser.add_argument('--top_k', type=int, default=1)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
549
examples/bloom/weight.py
Normal file
549
examples/bloom/weight.py
Normal file
@@ -0,0 +1,549 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import configparser
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm._utils import str_dtype_to_np
|
||||
from xtrt_llm.models import BloomForCausalLM
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
|
||||
def split(v, tp_size, idx, dim=0):
|
||||
if tp_size == 1:
|
||||
return v
|
||||
if len(v.shape) == 1:
|
||||
return np.ascontiguousarray(np.split(v, tp_size)[idx])
|
||||
else:
|
||||
return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx])
|
||||
|
||||
|
||||
def reorder_qkv_weight_or_bias(v, n_head, n_hidden, is_bias=False):
|
||||
""" Reorder the qkv weight.
|
||||
|
||||
Note that the shape of the fused QKV weights in HF is different from the
|
||||
shape that XTRT-LLM requires.
|
||||
HF: (num_heads x 3 x head_dim, hidden_size)
|
||||
XTRT-LLM: (3 x num_heads x head_dim, hidden_size)
|
||||
This is unlike to the other models in HF e.g. GPT where they have the
|
||||
same shape with XTRT-LLM, i.e., (3 x num_heads x head_dim, hidden_size). Also,
|
||||
to split across attention heads in tensor parallel, we reshape the qkv
|
||||
weight: (3, num_heads x head_dim, hidden).
|
||||
bias : (3, num_heads x head_dim).
|
||||
"""
|
||||
|
||||
head_dim = n_hidden // n_head
|
||||
|
||||
# (3 x hidden, ...) view as (num_heads, 3, head_dim, ...)
|
||||
v = v.reshape(n_head, 3, head_dim, -1)
|
||||
# permute to (3, num_heads, head_dim, ...)
|
||||
v = v.transpose((1, 0, 2, 3))
|
||||
# final shape: weight=(3, hidden, hidden) or bias=(3, hidden)
|
||||
if is_bias:
|
||||
return v.reshape(3, n_hidden)
|
||||
return v.reshape(3, n_hidden, n_hidden)
|
||||
|
||||
|
||||
def split_qkv_tp(xtrt_llm_bloom, v, tensor_parallel, rank):
|
||||
"""
|
||||
Splits the QKV matrix according to tensor parallelism
|
||||
"""
|
||||
n_heads = xtrt_llm_bloom._num_heads
|
||||
hidden_size = xtrt_llm_bloom._hidden_size
|
||||
v = reorder_qkv_weight_or_bias(v, n_heads, hidden_size, is_bias=False)
|
||||
split_v = split(v, tensor_parallel, rank, dim=1)
|
||||
split_v = split_v.reshape(3 * (hidden_size // tensor_parallel), hidden_size)
|
||||
|
||||
return np.ascontiguousarray(split_v)
|
||||
|
||||
|
||||
def split_qkv_bias_tp(xtrt_llm_bloom, v, tensor_parallel, rank):
|
||||
"""
|
||||
Splits the QKV bias according to tensor parallelism
|
||||
"""
|
||||
layer = xtrt_llm_bloom.layers[0]
|
||||
n_heads = layer.num_attention_heads
|
||||
hidden_size = layer.hidden_size
|
||||
v = reorder_qkv_weight_or_bias(v, n_heads, hidden_size, is_bias=True)
|
||||
split_v = split(v, tensor_parallel, rank, dim=1)
|
||||
split_v = split_v.reshape(3 * (hidden_size // tensor_parallel))
|
||||
return np.ascontiguousarray(split_v)
|
||||
|
||||
|
||||
def split_matrix_tp(v, tensor_parallel, rank, dim):
|
||||
return np.ascontiguousarray(split(v, tensor_parallel, rank, dim=dim))
|
||||
|
||||
|
||||
def get_weight(config, prefix, dtype):
|
||||
return config[prefix + '.weight'].to(dtype).detach().cpu().numpy()
|
||||
|
||||
|
||||
def get_bias(config, prefix, dtype):
|
||||
return config[prefix + '.bias'].to(dtype).detach().cpu().numpy()
|
||||
|
||||
|
||||
def get_weight_and_bias(config, prefix, dtype):
|
||||
return get_weight(config, prefix, dtype), get_bias(config, prefix, dtype)
|
||||
|
||||
|
||||
def set_layer_weight(layer, val, quant_mode):
|
||||
if quant_mode.is_int8_weight_only():
|
||||
plugin_weight_only_quant_type = torch.int8
|
||||
elif quant_mode.is_int4_weight_only():
|
||||
plugin_weight_only_quant_type = torch.quint4x2
|
||||
# use_weight_only = quant_mode.is_weight_only()
|
||||
use_weight_only = 0
|
||||
|
||||
if use_weight_only:
|
||||
v = np.ascontiguousarray(val.transpose())
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(v), plugin_weight_only_quant_type)
|
||||
# workaround for trt not supporting int8 inputs in plugins currently
|
||||
layer.weight.value = processed_torch_weights.view(
|
||||
dtype=torch.float32).numpy()
|
||||
layer.per_channel_scale.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
layer.weight.value = np.ascontiguousarray(val)
|
||||
|
||||
|
||||
def check_embedding_share(dir_path):
|
||||
share_embedding_table = False
|
||||
if Path(dir_path).exists():
|
||||
share_embedding_table = True
|
||||
return share_embedding_table
|
||||
|
||||
|
||||
def load_from_hf_bloom(xtrt_llm_bloom,
|
||||
hf_bloom,
|
||||
rank=0,
|
||||
tensor_parallel=1,
|
||||
fp16=False,
|
||||
use_parallel_embedding=False,
|
||||
sharding_dim=0,
|
||||
share_embedding_table=False):
|
||||
xtrt_llm.logger.info('Loading weights from HF BLOOM...')
|
||||
tik = time.time()
|
||||
|
||||
quant_mode = getattr(xtrt_llm_bloom, 'quant_mode', QuantMode(0))
|
||||
|
||||
model_params = dict(hf_bloom.named_parameters())
|
||||
dtype = torch.float16 if fp16 else torch.float32
|
||||
for l in range(hf_bloom.config.num_hidden_layers):
|
||||
prefix = f'transformer.h.{l}.'
|
||||
|
||||
qkv_weight, qkv_bias = get_weight_and_bias(
|
||||
model_params, prefix + 'self_attention.query_key_value', dtype)
|
||||
split_v = split_qkv_tp(xtrt_llm_bloom, qkv_weight, tensor_parallel,
|
||||
rank)
|
||||
set_layer_weight(xtrt_llm_bloom.layers[l].attention.qkv, split_v,
|
||||
quant_mode)
|
||||
xtrt_llm_bloom.layers[
|
||||
l].attention.qkv.bias.value = split_qkv_bias_tp(
|
||||
xtrt_llm_bloom, qkv_bias, tensor_parallel, rank)
|
||||
|
||||
attn_dense_weight, attn_dense_bias = get_weight_and_bias(
|
||||
model_params, prefix + 'self_attention.dense', dtype)
|
||||
split_v = split_matrix_tp(attn_dense_weight,
|
||||
tensor_parallel,
|
||||
rank,
|
||||
dim=1)
|
||||
set_layer_weight(xtrt_llm_bloom.layers[l].attention.dense, split_v,
|
||||
quant_mode)
|
||||
xtrt_llm_bloom.layers[
|
||||
l].attention.dense.bias.value = attn_dense_bias
|
||||
|
||||
mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
|
||||
model_params, prefix + 'mlp.dense_h_to_4h', dtype)
|
||||
split_v = split_matrix_tp(mlp_fc_weight, tensor_parallel, rank, dim=0)
|
||||
set_layer_weight(xtrt_llm_bloom.layers[l].mlp.fc, split_v,
|
||||
quant_mode)
|
||||
xtrt_llm_bloom.layers[l].mlp.fc.bias.value = split_matrix_tp(
|
||||
mlp_fc_bias, tensor_parallel, rank, dim=0)
|
||||
|
||||
mlp_proj_weight, mlp_proj_bias = get_weight_and_bias(
|
||||
model_params, prefix + 'mlp.dense_4h_to_h', dtype)
|
||||
split_v = split_matrix_tp(mlp_proj_weight, tensor_parallel, rank, dim=1)
|
||||
set_layer_weight(xtrt_llm_bloom.layers[l].mlp.proj, split_v,
|
||||
quant_mode)
|
||||
xtrt_llm_bloom.layers[l].mlp.proj.bias.value = mlp_proj_bias
|
||||
|
||||
# Layer norms do not use tensor parallelism
|
||||
input_ln_weight, input_ln_bias = get_weight_and_bias(
|
||||
model_params, prefix + 'input_layernorm', dtype)
|
||||
xtrt_llm_bloom.layers[
|
||||
l].input_layernorm.weight.value = input_ln_weight
|
||||
xtrt_llm_bloom.layers[l].input_layernorm.bias.value = input_ln_bias
|
||||
|
||||
post_ln_weight, post_ln_bias = get_weight_and_bias(
|
||||
model_params, prefix + 'post_attention_layernorm', dtype)
|
||||
xtrt_llm_bloom.layers[
|
||||
l].post_layernorm.weight.value = post_ln_weight
|
||||
xtrt_llm_bloom.layers[l].post_layernorm.bias.value = post_ln_bias
|
||||
|
||||
embed_w = get_weight(model_params, 'transformer.word_embeddings', dtype)
|
||||
if not share_embedding_table:
|
||||
xtrt_llm_bloom.lm_head.weight.value = split_matrix_tp(
|
||||
embed_w.copy(), tensor_parallel, rank, dim=0)
|
||||
|
||||
if not use_parallel_embedding:
|
||||
xtrt_llm_bloom.embedding.weight.value = embed_w
|
||||
else:
|
||||
assert hf_bloom.config.vocab_size % tensor_parallel == 0
|
||||
xtrt_llm_bloom.embedding.weight.value = split_matrix_tp(
|
||||
embed_w, tensor_parallel, rank, dim=sharding_dim)
|
||||
|
||||
embed_f_w, embed_f_b = get_weight_and_bias(
|
||||
model_params, 'transformer.word_embeddings_layernorm', dtype)
|
||||
xtrt_llm_bloom.ln_embed.weight.value = embed_f_w
|
||||
xtrt_llm_bloom.ln_embed.bias.value = embed_f_b
|
||||
|
||||
ln_f_w, ln_f_b = get_weight_and_bias(model_params, 'transformer.ln_f',
|
||||
dtype)
|
||||
xtrt_llm_bloom.ln_f.weight.value = ln_f_w
|
||||
xtrt_llm_bloom.ln_f.bias.value = ln_f_b
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|
||||
|
||||
|
||||
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
|
||||
suffix = f"{rank}.bin"
|
||||
if use_smooth_quant:
|
||||
sq_prefix = "int8."
|
||||
if quant_per_channel:
|
||||
sq_prefix += "col."
|
||||
suffix = sq_prefix + suffix
|
||||
return suffix
|
||||
|
||||
|
||||
def extract_layer_idx(name):
|
||||
ss = name.split('.')
|
||||
for s in ss:
|
||||
if s.isdigit():
|
||||
return s
|
||||
return None
|
||||
|
||||
|
||||
def parse_config(ini_file):
|
||||
bloom_config = configparser.ConfigParser()
|
||||
bloom_config.read(ini_file)
|
||||
|
||||
n_embd = bloom_config.getint('bloom', 'hidden_size')
|
||||
n_head = bloom_config.getint('bloom', 'n_head')
|
||||
n_layer = bloom_config.getint('bloom', 'n_layer')
|
||||
vocab_size = bloom_config.getint('bloom', 'vocab_size')
|
||||
do_layer_norm_before = bloom_config.getboolean('bloom',
|
||||
'do_layer_norm_before',
|
||||
fallback=True)
|
||||
rotary_pct = bloom_config.getfloat('bloom', 'rotary_pct', fallback=0.0)
|
||||
bias = bloom_config.getboolean('bloom', 'bias', fallback=True)
|
||||
inter_size = bloom_config.getint('bloom',
|
||||
'intermediate_size',
|
||||
fallback=None)
|
||||
dtype = bloom_config.get('bloom', 'storage_dtype', fallback='float32')
|
||||
|
||||
if inter_size is None:
|
||||
inter_size = 4 * n_embd
|
||||
|
||||
multi_query_mode = bloom_config.getboolean('bloom',
|
||||
'multi_query_mode',
|
||||
fallback=False)
|
||||
prompt_num_tasks = bloom_config.getint('bloom',
|
||||
'prompt_num_tasks',
|
||||
fallback=0)
|
||||
prompt_max_vocab_size = bloom_config.getint('bloom',
|
||||
'prompt_max_vocab_size',
|
||||
fallback=0)
|
||||
return n_embd, n_head, n_layer, vocab_size, do_layer_norm_before, rotary_pct, bias, inter_size, multi_query_mode, dtype, prompt_num_tasks, prompt_max_vocab_size
|
||||
|
||||
|
||||
def load_from_bin(xtrt_llm_bloom: BloomForCausalLM,
|
||||
dir_path,
|
||||
rank=0,
|
||||
tensor_parallel=1,
|
||||
dtype='float32',
|
||||
use_parallel_embedding=False,
|
||||
sharding_dim=0,
|
||||
share_embedding_table=False):
|
||||
xtrt_llm.logger.info('Loading weights from bin...')
|
||||
tik = time.time()
|
||||
|
||||
quant_mode = getattr(xtrt_llm_bloom, 'quant_mode', QuantMode(0))
|
||||
if quant_mode.is_int8_weight_only():
|
||||
torch.int8
|
||||
elif quant_mode.is_int4_weight_only():
|
||||
torch.quint4x2
|
||||
n_embd, n_head, n_layer, vocab_size, do_layer_norm_before, rotary_pct, bias, inter_size, multi_query_mode, *_ = parse_config(
|
||||
Path(dir_path) / 'config.ini')
|
||||
np_dtype = str_dtype_to_np(dtype)
|
||||
|
||||
def fromfile(dir_path, name, shape=None, dtype=None):
|
||||
dtype = np_dtype if dtype is None else dtype
|
||||
p = dir_path + '/' + name
|
||||
if Path(p).exists():
|
||||
t = np.fromfile(p, dtype=dtype)
|
||||
if shape is not None:
|
||||
t = t.reshape(shape)
|
||||
return t
|
||||
return None
|
||||
|
||||
def set_smoothquant_scale_factors(module,
|
||||
pre_scale_weight,
|
||||
dir_path,
|
||||
basename,
|
||||
shape,
|
||||
per_tok_dyn,
|
||||
per_channel,
|
||||
is_qkv=False,
|
||||
rank=None):
|
||||
suffix = "bin"
|
||||
if per_channel:
|
||||
if rank is not None:
|
||||
suffix = f"{rank}." + suffix
|
||||
suffix = "col." + suffix
|
||||
|
||||
col_shape = shape if (per_channel or is_qkv) else [1, 1]
|
||||
if per_tok_dyn:
|
||||
if pre_scale_weight is not None:
|
||||
pre_scale_weight.value = np.array([1.0], dtype=np.float32)
|
||||
t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}",
|
||||
col_shape, np.float32)
|
||||
module.per_channel_scale.value = t
|
||||
else:
|
||||
t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1],
|
||||
np.float32)
|
||||
pre_scale_weight.value = t
|
||||
t = fromfile(dir_path, f"{basename}scale_y_accum_quant.{suffix}",
|
||||
col_shape, np.float32)
|
||||
module.per_channel_scale.value = t
|
||||
t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1],
|
||||
np.float32)
|
||||
module.act_scale.value = t
|
||||
|
||||
def set_smoother(module, dir_path, base_name, shape, rank):
|
||||
suffix = f"{rank}.bin"
|
||||
t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape,
|
||||
np.float32)
|
||||
module.smoother.value = t
|
||||
|
||||
# Determine the quantization mode.
|
||||
quant_mode = getattr(xtrt_llm_bloom, "quant_mode", QuantMode(0))
|
||||
# Do we use SmoothQuant?
|
||||
use_smooth_quant = quant_mode.has_act_and_weight_quant()
|
||||
# Do we use quantization per token?
|
||||
quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling()
|
||||
# Do we use quantization per channel?
|
||||
quant_per_channel = quant_mode.has_per_channel_scaling()
|
||||
|
||||
# Do we use INT4/INT8 weight-only?
|
||||
quant_mode.is_weight_only()
|
||||
|
||||
# Int8 KV cache
|
||||
use_int8_kv_cache = quant_mode.has_int8_kv_cache()
|
||||
|
||||
'''
|
||||
def sq_trick(x):
|
||||
return x.view(np.float32) if use_smooth_quant else x
|
||||
'''
|
||||
|
||||
# Debug
|
||||
suffix = gen_suffix(rank, use_smooth_quant, quant_per_channel)
|
||||
# The type of weights.
|
||||
w_type = np_dtype if not use_smooth_quant else np.int8
|
||||
|
||||
vocab_embedding_weight = (fromfile(dir_path, 'model.wpe.bin',
|
||||
[vocab_size, n_embd]))
|
||||
embed_w = np.ascontiguousarray(
|
||||
split(vocab_embedding_weight.copy(), tensor_parallel, rank))
|
||||
if not share_embedding_table:
|
||||
xtrt_llm_bloom.lm_head.weight.value = embed_w
|
||||
|
||||
if not use_parallel_embedding:
|
||||
xtrt_llm_bloom.embedding.weight.value = np.ascontiguousarray(
|
||||
vocab_embedding_weight)
|
||||
else:
|
||||
assert vocab_size % tensor_parallel == 0
|
||||
xtrt_llm_bloom.embedding.weight.value = np.ascontiguousarray(
|
||||
split(vocab_embedding_weight,
|
||||
tensor_parallel,
|
||||
rank,
|
||||
dim=sharding_dim))
|
||||
|
||||
xtrt_llm_bloom.ln_embed.bias.value = (fromfile(
|
||||
dir_path, 'model.word_embeddings_layernorm.bias.bin'))
|
||||
xtrt_llm_bloom.ln_embed.weight.value = (fromfile(
|
||||
dir_path, 'model.word_embeddings_layernorm.weight.bin'))
|
||||
|
||||
xtrt_llm_bloom.ln_f.bias.value = (fromfile(
|
||||
dir_path, 'model.final_layernorm.bias.bin'))
|
||||
xtrt_llm_bloom.ln_f.weight.value = (fromfile(
|
||||
dir_path, 'model.final_layernorm.weight.bin'))
|
||||
|
||||
for i in range(n_layer):
|
||||
c_attn_out_dim = (3 * n_embd //
|
||||
tensor_parallel) if not multi_query_mode else (
|
||||
n_embd // tensor_parallel +
|
||||
(n_embd // n_head) * 2)
|
||||
xtrt_llm_bloom.layers[i].input_layernorm.weight.value = (fromfile(
|
||||
dir_path, 'model.layers.' + str(i) + '.input_layernorm.weight.bin'))
|
||||
xtrt_llm_bloom.layers[i].input_layernorm.bias.value = (fromfile(
|
||||
dir_path, 'model.layers.' + str(i) + '.input_layernorm.bias.bin'))
|
||||
|
||||
t = fromfile(
|
||||
dir_path, 'model.layers.' + str(i) +
|
||||
'.attention.query_key_value.weight.' + suffix,
|
||||
[n_embd, c_attn_out_dim], w_type)
|
||||
if t is not None:
|
||||
layer = xtrt_llm_bloom.layers[i].attention.qkv
|
||||
if use_smooth_quant:
|
||||
'''
|
||||
layer.weight.value = sq_trick(
|
||||
np.ascontiguousarray(np.transpose(t, [1, 0])))
|
||||
'''
|
||||
layer.weight.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
|
||||
set_smoothquant_scale_factors(
|
||||
layer,
|
||||
xtrt_llm_bloom.layers[i].input_layernorm.scale_to_int,
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.attention.query_key_value.',
|
||||
[1, c_attn_out_dim],
|
||||
quant_per_token_dyn,
|
||||
quant_per_channel,
|
||||
rank=rank,
|
||||
is_qkv=True)
|
||||
else:
|
||||
set_layer_weight(layer, np.transpose(t, [1, 0]), quant_mode)
|
||||
if bias:
|
||||
t = fromfile(
|
||||
dir_path, 'model.layers.' + str(i) +
|
||||
'.attention.query_key_value.bias.' + str(rank) + '.bin')
|
||||
if t is not None:
|
||||
layer.bias.value = np.ascontiguousarray(t)
|
||||
|
||||
t = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.attention.dense.weight.' + suffix,
|
||||
[n_embd // tensor_parallel, n_embd], w_type)
|
||||
layer = xtrt_llm_bloom.layers[i].attention.dense
|
||||
if use_smooth_quant:
|
||||
'''
|
||||
layer.weight.value = sq_trick(
|
||||
np.ascontiguousarray(np.transpose(t, [1, 0])))
|
||||
'''
|
||||
layer.weight.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
|
||||
dense_scale = getattr(xtrt_llm_bloom.layers[i].attention,
|
||||
"quantization_scaling_factor", None)
|
||||
set_smoothquant_scale_factors(
|
||||
layer, dense_scale, dir_path,
|
||||
'model.layers.' + str(i) + '.attention.dense.', [1, n_embd],
|
||||
quant_per_token_dyn, quant_per_channel)
|
||||
# set it to ones if dense layer is not applied smooth quant
|
||||
# layer.smoother.value = np.ones(
|
||||
# [1, n_embd // tensor_parallel], dtype=np.float32)
|
||||
# set it to the real smoother if dense layer is applied smooth quant
|
||||
set_smoother(layer, dir_path,
|
||||
'model.layers.' + str(i) + '.attention.dense',
|
||||
[1, n_embd // tensor_parallel], rank)
|
||||
else:
|
||||
set_layer_weight(layer, np.transpose(t, [1, 0]), quant_mode)
|
||||
if bias:
|
||||
layer.bias.value = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.attention.dense.bias.bin')
|
||||
|
||||
dst = xtrt_llm_bloom.layers[i].post_layernorm.weight
|
||||
dst.value = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.post_attention_layernorm.weight.bin')
|
||||
dst = xtrt_llm_bloom.layers[i].post_layernorm.bias
|
||||
dst.value = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.post_attention_layernorm.bias.bin')
|
||||
|
||||
t = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.dense_h_to_4h.weight.' + suffix,
|
||||
[n_embd, inter_size // tensor_parallel], w_type)
|
||||
layer = xtrt_llm_bloom.layers[i].mlp.fc
|
||||
if use_smooth_quant:
|
||||
'''
|
||||
layer.weight.value = sq_trick(
|
||||
np.ascontiguousarray(np.transpose(t, [1, 0])))
|
||||
'''
|
||||
layer.weight.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
|
||||
set_smoothquant_scale_factors(
|
||||
layer,
|
||||
xtrt_llm_bloom.layers[i].post_layernorm.scale_to_int,
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.dense_h_to_4h.',
|
||||
[1, inter_size // tensor_parallel],
|
||||
quant_per_token_dyn,
|
||||
quant_per_channel,
|
||||
rank=rank)
|
||||
else:
|
||||
set_layer_weight(layer, np.transpose(t, [1, 0]), quant_mode)
|
||||
if bias:
|
||||
layer.bias.value = fromfile(
|
||||
dir_path, 'model.layers.' + str(i) +
|
||||
'.mlp.dense_h_to_4h.bias.' + str(rank) + '.bin')
|
||||
|
||||
t = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.dense_4h_to_h.weight.' + suffix,
|
||||
[inter_size // tensor_parallel, n_embd], w_type)
|
||||
layer = xtrt_llm_bloom.layers[i].mlp.proj
|
||||
if use_smooth_quant:
|
||||
'''
|
||||
layer.weight.value = sq_trick(
|
||||
np.ascontiguousarray(np.transpose(t, [1, 0])))
|
||||
'''
|
||||
layer.weight.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
|
||||
proj_scale = getattr(xtrt_llm_bloom.layers[i].mlp,
|
||||
"quantization_scaling_factor", None)
|
||||
set_smoothquant_scale_factors(
|
||||
layer, proj_scale, dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.dense_4h_to_h.', [1, n_embd],
|
||||
quant_per_token_dyn, quant_per_channel)
|
||||
# set it to ones if proj layer is not applied smooth quant
|
||||
# layer.smoother.value = np.ones(
|
||||
# [1, inter_size // tensor_parallel], dtype=np.float32)
|
||||
# set it to the real smoother if proj layer is applied smooth quant
|
||||
set_smoother(layer, dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.dense_4h_to_h',
|
||||
[1, inter_size // tensor_parallel], rank)
|
||||
else:
|
||||
set_layer_weight(layer, np.transpose(t, [1, 0]), quant_mode)
|
||||
if bias:
|
||||
layer.bias.value = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.dense_4h_to_h.bias.bin')
|
||||
|
||||
if use_int8_kv_cache:
|
||||
t = fromfile(
|
||||
dir_path, 'model.layers.' + str(i) +
|
||||
'.attention.query_key_value.scale_y_quant_orig.bin', [1],
|
||||
np.float32)
|
||||
xtrt_llm_bloom.layers[
|
||||
i].attention.kv_orig_quant_scale.value = 1.0 / t
|
||||
xtrt_llm_bloom.layers[i].attention.kv_quant_orig_scale.value = t
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|
||||
Reference in New Issue
Block a user