add pkgs
This commit is contained in:
2
examples/llama/.gitignore
vendored
Normal file
2
examples/llama/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
llama*
|
||||
tokenizer.model
|
||||
183
examples/llama/README.md
Normal file
183
examples/llama/README.md
Normal file
@@ -0,0 +1,183 @@
|
||||
# LLaMA
|
||||
|
||||
This document shows how to build and run a LLaMA model in XTRT-LLM on both single XPU and single node multi-XPU.
|
||||
|
||||
## Overview
|
||||
|
||||
The XTRT-LLM LLaMA example code is located in [`examples/llama`](./). There are several main files in that folder:
|
||||
|
||||
* [`build.py`](./build.py) to build the engine(s) needed to run the LLaMA model,
|
||||
* [`run.py`](./run.py) to run the inference on an input text,
|
||||
|
||||
## Support Matrix
|
||||
* FP16
|
||||
* INT8 & INT4 Weight-Only
|
||||
* Tensor Parallel
|
||||
|
||||
## Usage
|
||||
|
||||
The XTRT-LLM LLaMA example code locates at [examples/llama](./). It takes HF weights as input, and builds the corresponding XTRT engines. The number of XTRT engines depends on the number of XPUs used to run inference.
|
||||
|
||||
### Build XTRT engine(s)
|
||||
|
||||
Need to prepare the HF LLaMA checkpoint first by following the guides here https://huggingface.co/docs/transformers/main/en/model_doc/llama.
|
||||
|
||||
XTRT-LLM LLaMA builds XTRT engine(s) from HF checkpoint. If no checkpoint directory is specified, XTRT-LLM will build engine(s) with dummy weights.
|
||||
|
||||
Normally `build.py` only requires single XPU, but if you've already got all the XPUs needed while inferencing, you could enable parallelly building to make the engine building process faster by adding `--parallel_build` argument. Please note that currently `parallel_build` feature only supports single node.
|
||||
|
||||
Here're some examples:
|
||||
|
||||
```bash
|
||||
# Build a single-XPU float16 engine from HF weights.
|
||||
# use_gpt_attention_plugin is necessary in LLaMA.
|
||||
# It is recommend to use --use_gpt_attention_plugin for better performance
|
||||
|
||||
# Build the LLaMA 7B model using a single XPU and FP16.
|
||||
python build.py --model_dir ./downloads/llama-7b-hf/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/llama-7b-hf/trt_engines/fp16/1-XPU/
|
||||
|
||||
|
||||
# Build the LLaMA 7B model using a single XPU and apply INT8 weight-only quantization.
|
||||
python build.py --model_dir ./downloads/llama-7b-hf/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_weight_only \
|
||||
--output_dir ./downloads/llama-7b-hf/trt_engines/weight_only/1-XPU/
|
||||
|
||||
# Build LLaMA 7B using 2-way tensor parallelism.
|
||||
python build.py --model_dir ./downloads/llama-7b-hf/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/llama-7b-hf/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2 \
|
||||
--tp_size 2 \
|
||||
--parallel_build
|
||||
|
||||
|
||||
# Build LLaMA 13B using 2-way tensor parallelism.
|
||||
python build.py --model_dir ./downloads/llama13b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/llama13b/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2 \
|
||||
--tp_size 2 \
|
||||
--parallel_build
|
||||
```
|
||||
|
||||
#### LLaMA v2 Updates
|
||||
The LLaMA v2 models with 7B and 13B are compatible with the LLaMA v1 implementation. The above
|
||||
commands still work.
|
||||
|
||||
|
||||
For LLaMA v2 70B, there is a restriction on tensor parallelism that the number of KV heads
|
||||
must be **divisible by the number of XPUs**. For example, since the 70B model has 8 KV heads, you can run it with
|
||||
2, 4 or 8 XPUs
|
||||
|
||||
|
||||
```bash
|
||||
# Build LLaMA 70B using 8-way tensor parallelism.
|
||||
python build.py --model_dir ./downloads/llama2-70b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/llama2-70b/trt_engines/fp16/8-XPU/ \
|
||||
--world_size 8 \
|
||||
--tp_size 8 \
|
||||
--parallel_build
|
||||
```
|
||||
|
||||
Same instructions can be applied to fine-tuned versions of the LLaMA v2 models (e.g. 7Bf or llama-2-7b-chat).
|
||||
|
||||
|
||||
Test with `summarize.py`: `pip install nltk rouge_score`
|
||||
|
||||
```bash
|
||||
python summarize.py --test_trt_llm \
|
||||
--hf_model_location ./downloads/llama-7b-hf \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/llama-7b-hf/trt_engines/fp16/1-XPU
|
||||
```
|
||||
|
||||
#### SmoothQuant
|
||||
|
||||
The smoothquant supports both LLaMA v1 and LLaMA v2. Unlike the FP16 build where the HF weights are processed and loaded into the XTRT-LLM directly, the SmoothQuant needs to load INT8 weights which should be pre-processed before building an engine.
|
||||
|
||||
Example:
|
||||
```bash
|
||||
python3 hf_llama_convert.py -i ./downloads/llama-7b-hf -o ./downloads/smooth_llama_7B/sq0.8/ -sq 0.8 --tensor-parallelism 1 --storage-type fp16
|
||||
```
|
||||
|
||||
Note `hf_llama_convert.py` run with pytorch, and
|
||||
1. `torch-cpu` has better accuracy than XPyTorch generally.
|
||||
2. XPyTorch often use more than 32GB GM, thus more XPU are necessary to finish it.
|
||||
3. add `-p=1` if run with XPyTorch.
|
||||
|
||||
We offer converted data [here](https://fsh.bcebos.com/v1/klx-llm/pretrained_models/quantization/smooth_llama_7B.tar.gz) for LLaMa-7b with sq of 0.6.
|
||||
|
||||
[`build.py`](./build.py) add new options for the support of INT8 inference of SmoothQuant models.
|
||||
|
||||
`--use_smooth_quant` is the starting point of INT8 inference. By default, it
|
||||
will run the model in the _per-tensor_ mode.
|
||||
|
||||
`--per-token` and `--per-channel` are not supported yet.
|
||||
|
||||
Examples of build invocations:
|
||||
|
||||
```bash
|
||||
# Build model for SmoothQuant in the _per_tensor_ mode.
|
||||
python3 build.py --ft_model_dir=./downloads/smooth_llama_7B/sq0.8/1-XPU/ \
|
||||
--use_smooth_quant \
|
||||
--output_dir ./downloads/smooth_llama_7B/sq0.8/trt_engines/fp16/1-XPU/
|
||||
```
|
||||
|
||||
Note we use `--ft_model_dir` instead of `--model_dir` and `--meta_ckpt_dir` since SmoothQuant model needs INT8 weights and various scales from the binary files.
|
||||
|
||||
### Run
|
||||
|
||||
Before running the examples, make sure set the environment variables:
|
||||
```
|
||||
export PYTORCH_NO_XPU_MEMORY_CACHING=0 # disable XPytorch cache XPU memory.
|
||||
export XMLIR_D_XPU_L3_SIZE=0 # disable XPytorch use L3.
|
||||
```
|
||||
If you are runing with multiple XPUs and no L3 space, you can set `BKCL_CCIX_BUFFER_GM=1` to disable L3.
|
||||
|
||||
|
||||
To run a XTRT-LLM LLaMA model using the engines generated by `build.py`
|
||||
|
||||
```bash
|
||||
# With fp16 inference
|
||||
python3 run.py --max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/llama-7b-hf/ \
|
||||
--engine_dir=./downloads/llama-7b-hf/trt_engines/fp16/1-XPU/
|
||||
|
||||
# With fp16 inference, SmoothQuant
|
||||
python3 run.py --max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/llama-7b-hf/ \
|
||||
--engine_dir=./downloads/smooth_llama_7B/sq0.8/trt_engines/fp16/1-XPU/
|
||||
|
||||
```
|
||||
|
||||
### Summarization using the LLaMA model
|
||||
|
||||
```bash
|
||||
# Run summarization using the LLaMA 7B model in FP16.
|
||||
python summarize.py --test_trt_llm \
|
||||
--hf_model_location ./downloads/llama-7b-hf/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/llama-7b-hf/trt_engines/fp16/1-XPU/
|
||||
|
||||
# Run summarization using the LLaMA 7B model quantized to INT8.
|
||||
python summarize.py --test_trt_llm \
|
||||
--hf_model_location ./downloads/llama-7b-hf/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/llama-7b-hf/trt_engines/weight_only/1-XPU/
|
||||
|
||||
# Run summarization using the LLaMA 7B model in FP16 using two XPUs.
|
||||
mpirun -n 2 --allow-run-as-root \
|
||||
python summarize.py --test_trt_llm \
|
||||
--hf_model_location ./downloads/llama-7b-hf/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/llama-7b-hf/trt_engines/fp16/2-XPU/
|
||||
```
|
||||
179
examples/llama/README_CN.md
Normal file
179
examples/llama/README_CN.md
Normal file
@@ -0,0 +1,179 @@
|
||||
# LLaMA
|
||||
|
||||
本文档介绍了如何使用昆仑芯XTRT-LLM在单XPU和单节点多XPU上构建和运行LLaMA模型。
|
||||
|
||||
## 概述
|
||||
|
||||
XTRT-LLM LLMa示例代码位于 [`examples/llama`](./). 此文件夹中有以下几个主要文件:
|
||||
|
||||
* [`build.py`](./build.py) 构建运行LLaMa模型所需的XTRT引擎
|
||||
* [`run.py`](./run.py) 基于输入的文字进行推理
|
||||
|
||||
## 支持的矩阵
|
||||
|
||||
* FP16
|
||||
* INT8 Weight-Only
|
||||
* Tensor Parallel
|
||||
|
||||
## 使用说明
|
||||
|
||||
XTRT-LLM LLaMa示例代码位于[examples/llama](./)。它使用HF权重作为输入,并且构建对应的XTRT引擎。XTRT引擎的数量取决于为了运行推理而是用的XPU个数。
|
||||
|
||||
### 构建XTRT引擎
|
||||
|
||||
需要先按照下面的指南准备HF LLaMA checkpoint:https://huggingface.co/docs/transformers/main/en/model_doc/llama。
|
||||
|
||||
XTRT-LLM LLaMA从HF checkpoint构建XTRT引擎。如果未指定checkpoint目录,XTRT-LLM将使用伪权重构建引擎。
|
||||
|
||||
通常 `build.py`只需要单个XPU,但如果您已经获得了推理所需的所有XPU,则可以通过添加 `--parallel_build` 参数来启用并行构建,从而加快引擎构建过程。请注意,目前`parallel_build`仅支持单个节点XPU。
|
||||
|
||||
以下是一些示例:
|
||||
|
||||
```bash
|
||||
# Build a single-XPU float16 engine from HF weights.
|
||||
# use_gpt_attention_plugin is necessary in LLaMA.
|
||||
# It is recommend to use --use_gpt_attention_plugin for better performance
|
||||
|
||||
# Build the LLaMA 7B model using a single XPU and FP16.
|
||||
python build.py --model_dir ./downloads/llama-7b-hf/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/llama-7b-hf/trt_engines/fp16/1-XPU/
|
||||
|
||||
|
||||
# Build the LLaMA 7B model using a single XPU and apply INT8 weight-only quantization.
|
||||
python build.py --model_dir ./downloads/llama-7b-hf/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_weight_only \
|
||||
--output_dir ./downloads/llama-7b-hf/trt_engines/weight_only/1-XPU/
|
||||
|
||||
# Build LLaMA 7B using 2-way tensor parallelism.
|
||||
python build.py --model_dir ./downloads/llama-7b-hf/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/llama-7b-hf/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2 \
|
||||
--tp_size 2 \
|
||||
--parallel_build
|
||||
|
||||
|
||||
# Build LLaMA 13B using 2-way tensor parallelism.
|
||||
python build.py --model_dir ./downloads/llama13b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/llama13b/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2 \
|
||||
--tp_size 2 \
|
||||
--parallel_build
|
||||
```
|
||||
|
||||
#### LLaMA v2 更新
|
||||
|
||||
LLaMA v2-7B和13B模型与 LLaMA v1的实现是兼容的,以上命令仍然有效。
|
||||
|
||||
对于LLaMA v2 70B,张量并行性有一个限制,即KV heads的数量必须可以被XPU的数量整除。例如,由于70B模型有8个KV heads,您可以使用2、4或8个XPU运行它。
|
||||
|
||||
|
||||
```bash
|
||||
# Build LLaMA 70B using 8-way tensor parallelism.
|
||||
python build.py --model_dir ./downloads/llama2-70b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/llama2-70b/trt_engines/fp16/8-XPU/ \
|
||||
--world_size 8 \
|
||||
--tp_size 8 \
|
||||
--parallel_build
|
||||
```
|
||||
|
||||
相同的指令可以应用于LLaMA v2模型的微调版本(例如7Bf或LLaMA-2-7b-chat)。
|
||||
|
||||
使用`summarize.py`进行测试:`pip install nltk rouge_score`
|
||||
|
||||
```bash
|
||||
python summarize.py --test_trt_llm \
|
||||
--hf_model_location ./downloads/llama-7b-hf \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/llama-7b-hf/trt_engines/fp16/1-XPU
|
||||
```
|
||||
|
||||
#### SmoothQuant
|
||||
|
||||
SmoothQuant同时支持LLaMA v1和v2。与FP16的HF权重可以直接被处理并加载到XTRT-LLM不同,SmoothQuant需要加载INT8权重,而INT8权重在构建引擎之前需要进行预处理。
|
||||
|
||||
示例:
|
||||
```bash
|
||||
python3 hf_llama_convert.py -i ./downloads/llama-7b-hf -o ./downloads/smooth_llama_7B/sq0.8/ -sq 0.8 --tensor-parallelism 1 --storage-type fp16
|
||||
```
|
||||
|
||||
注意:使用PyTorch运行`hf_llama_convert.py`,并且
|
||||
1. 'torch-cpu' 通常比XPyTorch精度更高
|
||||
2. XPyTorch 通常使用超过32GB的GM,因此需要更多的XPU来完成它。
|
||||
3. 使用XPyTorch运行时,请添加`-p=1`。
|
||||
|
||||
为SmoothQuant 0.6的LLaMa 7B模型,我们提供这些[转换数据](https://fsh.bcebos.com/v1/klx-llm/pretrained_models/quantization/smooth_llama_7B.tar.gz):
|
||||
|
||||
`build.py`增加了新的选项来支持SmoothQuant模型的INT8推理。
|
||||
|
||||
`--use_smooth_quant` 是INT8推理的起点。默认情况下,它将以`--per-token`模式运行模型。
|
||||
`--per-token`和`--per-channel`目前还不支持。
|
||||
|
||||
构建调用实例:
|
||||
|
||||
```bash
|
||||
# Build model for SmoothQuant in the _per_tensor_ mode.
|
||||
python3 build.py --ft_model_dir=./downloads/smooth_llama_7B/sq0.8/1-XPU/ \
|
||||
--use_smooth_quant \
|
||||
--output_dir ./downloads/smooth_llama_7B/sq0.8/trt_engines/fp16/1-XPU/
|
||||
```
|
||||
|
||||
注意:我们使用`--ft_model_dir`而不是`--model_dir`和`--meta_ckpt_dir`,因为SmoothQuant模型需要INT8权重和二进制文件中的各种scales。
|
||||
|
||||
### 运行
|
||||
|
||||
在运行示例之前,请确保设置环境变量:
|
||||
|
||||
```
|
||||
export PYTORCH_NO_XPU_MEMORY_CACHING=0 # disable XPytorch cache XPU memory.
|
||||
export XMLIR_D_XPU_L3_SIZE=0 # disable XPytorch use L3.
|
||||
```
|
||||
|
||||
如果使用多个XPU且没有L3空间运行,则可以通过设置`BKCL_CCIX_BUFFER_GM=1`以禁用L3。
|
||||
|
||||
使用`build.py`生成的引擎运行XTRT-LLM LLaMA模型:
|
||||
|
||||
```bash
|
||||
# With fp16 inference
|
||||
python3 run.py --max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/llama-7b-hf/ \
|
||||
--engine_dir=./downloads/llama-7b-hf/trt_engines/fp16/1-XPU/
|
||||
|
||||
# With fp16 inference, SmoothQuant
|
||||
python3 run.py --max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/llama-7b-hf/ \
|
||||
--engine_dir=./downloads/smooth_llama_7B/sq0.8/trt_engines/fp16/1-XPU/
|
||||
|
||||
```
|
||||
|
||||
### 使用LLaMA模型进行总结
|
||||
|
||||
```bash
|
||||
# Run summarization using the LLaMA 7B model in FP16.
|
||||
python summarize.py --test_trt_llm \
|
||||
--hf_model_location ./downloads/llama-7b-hf/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/llama-7b-hf/trt_engines/fp16/1-XPU/
|
||||
|
||||
# Run summarization using the LLaMA 7B model quantized to INT8.
|
||||
python summarize.py --test_trt_llm \
|
||||
--hf_model_location ./downloads/llama-7b-hf/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/llama-7b-hf/trt_engines/weight_only/1-XPU/
|
||||
|
||||
# Run summarization using the LLaMA 7B model in FP16 using two XPUs.
|
||||
mpirun -n 2 --allow-run-as-root \
|
||||
python summarize.py --test_trt_llm \
|
||||
--hf_model_location ./downloads/llama-7b-hf/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir ./downloads/llama-7b-hf/trt_engines/fp16/2-XPU/
|
||||
```
|
||||
BIN
examples/llama/__pycache__/build.cpython-38.pyc
Normal file
BIN
examples/llama/__pycache__/build.cpython-38.pyc
Normal file
Binary file not shown.
BIN
examples/llama/__pycache__/weight.cpython-38.pyc
Normal file
BIN
examples/llama/__pycache__/weight.cpython-38.pyc
Normal file
Binary file not shown.
662
examples/llama/build.py
Normal file
662
examples/llama/build.py
Normal file
@@ -0,0 +1,662 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch.multiprocessing as mp
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
from weight import (load_from_awq_llama, load_from_binary, load_from_gptq_llama,
|
||||
load_from_hf_llama, load_from_meta_llama)
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm._utils import str_dtype_to_xtrt
|
||||
from xtrt_llm.builder import Builder
|
||||
from xtrt_llm.layers.attention import PositionEmbeddingType
|
||||
from xtrt_llm.logger import logger
|
||||
from xtrt_llm.mapping import Mapping
|
||||
from xtrt_llm.models import (smooth_quantize, weight_only_groupwise_quantize,
|
||||
weight_only_quantize)
|
||||
from xtrt_llm.network import net_guard
|
||||
from xtrt_llm.plugin.plugin import ContextFMHAType
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
from weight import parse_ft_config # isort:skip
|
||||
|
||||
MODEL_NAME = "llama"
|
||||
|
||||
# 2 routines: get_engine_name, serialize_engine
|
||||
# are direct copy from gpt example, TODO: put in utils?
|
||||
|
||||
|
||||
def get_engine_name(model, dtype, tp_size, pp_size, rank):
|
||||
if pp_size == 1:
|
||||
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
|
||||
return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
|
||||
pp_size, rank)
|
||||
|
||||
|
||||
def serialize_engine(engine, path):
|
||||
logger.info(f'Serializing engine to {path}...')
|
||||
tik = time.time()
|
||||
engine.serialize(path)
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Engine serialized. Total time: {t}')
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--world_size', type=int, default=1)
|
||||
parser.add_argument('--tp_size', type=int, default=1)
|
||||
parser.add_argument('--pp_size', type=int, default=1)
|
||||
parser.add_argument('--model_dir', type=str, default=None)
|
||||
parser.add_argument('--ft_model_dir', type=str, default=None)
|
||||
parser.add_argument('--meta_ckpt_dir', type=str, default=None)
|
||||
parser.add_argument('--quant_ckpt_path', type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='float16',
|
||||
# choices=['float32', 'bfloat16', 'float16'])
|
||||
choices=['float32', 'float16'])
|
||||
parser.add_argument(
|
||||
'--timing_cache',
|
||||
type=str,
|
||||
default='model.cache',
|
||||
help=
|
||||
'The path of to read timing cache from, will be ignored if the file does not exist'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--opt_memory_use',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Whether to use Host memory optimization for building engine')
|
||||
parser.add_argument('--log_level', type=str, default='info')
|
||||
parser.add_argument('--vocab_size', type=int, default=32000)
|
||||
parser.add_argument('--n_layer', type=int, default=32)
|
||||
parser.add_argument('--n_positions', type=int, default=2048)
|
||||
parser.add_argument('--n_embd', type=int, default=4096)
|
||||
parser.add_argument('--n_head', type=int, default=32)
|
||||
parser.add_argument('--n_kv_head', type=int, default=None)
|
||||
parser.add_argument('--multiple_of', type=int, default=256)
|
||||
parser.add_argument('--ffn_dim_multiplier', type=float, default=1.0)
|
||||
parser.add_argument('--inter_size', type=int, default=None)
|
||||
parser.add_argument('--hidden_act', type=str, default='silu')
|
||||
parser.add_argument('--rms_norm_eps', type=float, default=1e-06)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--max_input_len', type=int, default=2048)
|
||||
parser.add_argument('--max_output_len', type=int, default=512)
|
||||
parser.add_argument('--max_beam_width', type=int, default=1)
|
||||
parser.add_argument('--rotary_base', type=float, default=10000.0)
|
||||
parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--use_gpt_attention_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
# choices=['float16', 'bfloat16', 'float32'])
|
||||
choices=['float32', 'float16'])
|
||||
parser.add_argument(
|
||||
'--use_gemm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
# choices=['float16', 'bfloat16', 'float32'])
|
||||
choices=['float32', 'float16'])
|
||||
|
||||
parser.add_argument(
|
||||
'--use_rmsnorm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default=False,
|
||||
# choices=['float16', 'float32', 'bfloat16'])
|
||||
choices=['float32', 'float16'])
|
||||
parser.add_argument('--parallel_build', default=False, action='store_true')
|
||||
parser.add_argument('--enable_context_fmha',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--enable_context_fmha_fp32_acc',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--enable_debug_output',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--gpus_per_node', type=int, default=8)
|
||||
parser.add_argument('--builder_opt', type=int, default=None)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='llama_outputs',
|
||||
help=
|
||||
'The path to save the serialized engine files, timing cache file and model configs'
|
||||
)
|
||||
parser.add_argument('--remove_input_padding',
|
||||
default=False,
|
||||
action='store_true')
|
||||
|
||||
# Arguments related to the quantization of the model.
|
||||
parser.add_argument(
|
||||
'--use_smooth_quant',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.'
|
||||
'See --per_channel and --per_token for finer-grained quantization options.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_channel',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor for the GEMM\'s result. '
|
||||
'per_channel instead uses a different static scaling factor for each channel. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--per_token',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale activations in the int8 range. '
|
||||
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--per_group',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale weights in the int4 range. '
|
||||
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
||||
'The flag is built for GPTQ/AWQ quantization.')
|
||||
parser.add_argument('--group_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Group size used in GPTQ/AWQ quantization.')
|
||||
parser.add_argument(
|
||||
'--int8_kv_cache',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_parallel_embedding',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--embedding_sharding_dim',
|
||||
type=int,
|
||||
default=1, # Meta does TP on hidden dim
|
||||
choices=[0, 1],
|
||||
help=
|
||||
'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
|
||||
'To shard it along hidden dimension, set embedding_sharding_dim=1'
|
||||
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_weight_only',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
||||
'See --weight_only_precision to set the precision')
|
||||
parser.add_argument(
|
||||
'--weight_only_precision',
|
||||
const='int8',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default='int8',
|
||||
choices=['int16', 'int8', 'int4', 'int4_awq', 'int4_gptq'],
|
||||
help=
|
||||
'Define the precision for the weights when using weight-only quantization.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_inflight_batching',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activates inflight batching mode of gptAttentionPlugin.")
|
||||
parser.add_argument(
|
||||
'--paged_kv_cache',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
|
||||
)
|
||||
parser.add_argument('--tokens_per_block',
|
||||
type=int,
|
||||
default=64,
|
||||
help='Number of tokens per block in paged KV cache')
|
||||
parser.add_argument(
|
||||
'--max_num_tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Define the max number of tokens supported by the engine')
|
||||
parser.add_argument(
|
||||
'--strongly_typed',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_custom_all_reduce',
|
||||
action='store_true',
|
||||
help=
|
||||
'Activates latency-optimized algorithm for all-reduce instead of NCCL.')
|
||||
parser.add_argument('--gather_all_token_logits',
|
||||
action='store_true',
|
||||
default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
xtrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
assert not (
|
||||
args.use_smooth_quant and args.use_weight_only
|
||||
), "You cannot enable both SmoothQuant and INT8 weight-only together."
|
||||
|
||||
if not args.remove_input_padding:
|
||||
if args.use_gpt_attention_plugin:
|
||||
logger.warning(
|
||||
f"It is recommended to specify --remove_input_padding when using GPT attention plugin"
|
||||
)
|
||||
|
||||
if args.use_inflight_batching:
|
||||
if not args.use_gpt_attention_plugin:
|
||||
args.use_gpt_attention_plugin = 'float16'
|
||||
logger.info(
|
||||
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
|
||||
)
|
||||
if not args.remove_input_padding:
|
||||
args.remove_input_padding = True
|
||||
logger.info(
|
||||
"Using remove input padding for inflight batching mode.")
|
||||
if not args.paged_kv_cache:
|
||||
args.paged_kv_cache = True
|
||||
logger.info("Using paged KV cache for inflight batching mode.")
|
||||
|
||||
if args.use_smooth_quant:
|
||||
args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
|
||||
args.per_channel)
|
||||
elif args.use_weight_only:
|
||||
if args.per_group:
|
||||
args.quant_mode = QuantMode.from_description(
|
||||
quantize_weights=True,
|
||||
quantize_activations=False,
|
||||
per_token=False,
|
||||
per_channel=False,
|
||||
per_group=True,
|
||||
use_int4_weights=True)
|
||||
else:
|
||||
args.quant_mode = QuantMode.use_weight_only(
|
||||
args.weight_only_precision == 'int4')
|
||||
else:
|
||||
args.quant_mode = QuantMode(0)
|
||||
|
||||
if args.int8_kv_cache:
|
||||
args.quant_mode = args.quant_mode.set_int8_kv_cache()
|
||||
|
||||
if args.rotary_scaling is not None:
|
||||
rotary_scaling = {
|
||||
"type": args.rotary_scaling[0],
|
||||
"factor": float(args.rotary_scaling[1])
|
||||
}
|
||||
assert rotary_scaling["type"] in ["linear", "dynamic"]
|
||||
assert rotary_scaling["factor"] > 1.0
|
||||
args.rotary_scaling = rotary_scaling
|
||||
if rotary_scaling["type"] == "dynamic":
|
||||
assert not args.remove_input_padding, "TODO: Not supported yet"
|
||||
|
||||
# Since gpt_attenttion_plugin is the only way to apply RoPE now,
|
||||
# force use the plugin for now with the correct data type.
|
||||
args.use_gpt_attention_plugin = args.dtype
|
||||
if args.model_dir is not None:
|
||||
hf_config = LlamaConfig.from_pretrained(args.model_dir)
|
||||
args.inter_size = hf_config.intermediate_size # override the inter_size for LLaMA
|
||||
args.n_embd = hf_config.hidden_size
|
||||
args.n_head = hf_config.num_attention_heads
|
||||
if hasattr(hf_config, "num_key_value_heads"):
|
||||
args.n_kv_head = hf_config.num_key_value_heads
|
||||
args.n_layer = hf_config.num_hidden_layers
|
||||
args.n_positions = hf_config.max_position_embeddings
|
||||
args.vocab_size = hf_config.vocab_size
|
||||
args.hidden_act = hf_config.hidden_act
|
||||
args.rms_norm_eps = hf_config.rms_norm_eps
|
||||
elif args.meta_ckpt_dir is not None:
|
||||
with open(Path(args.meta_ckpt_dir, "params.json")) as fp:
|
||||
meta_config: dict = json.load(fp)
|
||||
args.n_embd = meta_config["dim"]
|
||||
args.n_head = meta_config["n_heads"]
|
||||
args.n_layer = meta_config["n_layers"]
|
||||
args.n_kv_head = meta_config.get("n_kv_heads", args.n_head)
|
||||
args.multiple_of = meta_config["multiple_of"]
|
||||
args.ffn_dim_multiplier = meta_config.get("ffn_dim_multiplier", 1)
|
||||
n_embd = int(4 * args.n_embd * 2 / 3)
|
||||
args.inter_size = args.multiple_of * (
|
||||
(int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1) //
|
||||
args.multiple_of)
|
||||
args.rms_norm_eps = meta_config["norm_eps"]
|
||||
elif args.ft_model_dir is not None:
|
||||
n_embd, n_head, n_layer, n_positions, vocab_size, hidden_act, inter_size, n_kv_head = parse_ft_config(
|
||||
Path(args.ft_model_dir) / "config.ini")
|
||||
args.inter_size = inter_size # override the inter_size for LLaMA
|
||||
args.n_kv_head = n_kv_head
|
||||
args.n_embd = n_embd
|
||||
args.n_head = n_head
|
||||
args.n_layer = n_layer
|
||||
args.n_positions = n_positions
|
||||
args.vocab_size = vocab_size
|
||||
args.hidden_act = hidden_act
|
||||
args.rms_norm_eps = 1e-06
|
||||
logger.warning("Set rms_norm_eps to 1e-06 directly.")
|
||||
assert args.use_gpt_attention_plugin, "LLaMa must use gpt attention plugin"
|
||||
if args.n_kv_head is None:
|
||||
args.n_kv_head = args.n_head
|
||||
elif args.n_kv_head != args.n_head:
|
||||
assert (args.n_head % args.n_kv_head) == 0, \
|
||||
"MQA/GQA requires the number of heads to be divisible by the number of K/V heads."
|
||||
assert (args.n_kv_head % args.tp_size) == 0 or (args.tp_size % args.n_kv_head) == 0, \
|
||||
"MQA/GQA requires either the number of K/V heads to be divisible by the tensor parallelism size OR " \
|
||||
"the tensor parallelism size to be divisible by the number of K/V heads."
|
||||
|
||||
# if args.dtype == 'bfloat16':
|
||||
# assert args.use_gemm_plugin, "Please use gemm plugin when dtype is bfloat16"
|
||||
|
||||
assert args.pp_size * args.tp_size == args.world_size
|
||||
|
||||
if args.max_num_tokens is not None:
|
||||
assert args.enable_context_fmha
|
||||
|
||||
if args.inter_size is None:
|
||||
# this should not be need when loading a real model
|
||||
# but it is helpful when creating a dummy model without loading any real weights
|
||||
n_embd = int(4 * args.n_embd * 2 / 3)
|
||||
args.inter_size = args.multiple_of * (
|
||||
(int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1) //
|
||||
args.multiple_of)
|
||||
logger.info(f"Setting inter_size to {args.inter_size}.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def build_rank_engine(builder: Builder,
|
||||
builder_config: xtrt_llm.builder.BuilderConfig,
|
||||
engine_name, rank, args):
|
||||
'''
|
||||
@brief: Build the engine on the given rank.
|
||||
@param rank: The rank to build the engine.
|
||||
@param args: The cmd line arguments.
|
||||
@return: The built engine.
|
||||
'''
|
||||
dtype = str_dtype_to_xtrt(args.dtype)
|
||||
mapping = Mapping(world_size=args.world_size,
|
||||
rank=rank,
|
||||
tp_size=args.tp_size,
|
||||
pp_size=args.pp_size)
|
||||
|
||||
assert args.n_layer % args.pp_size == 0, \
|
||||
f"num_layers {args.n_layer} must be a multiple of pipeline parallelism size {args.pp_size}"
|
||||
|
||||
# Initialize Module
|
||||
xtrt_llm_llama = xtrt_llm.models.LLaMAForCausalLM(
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
num_kv_heads=args.n_kv_head,
|
||||
hidden_size=args.n_embd,
|
||||
vocab_size=args.vocab_size,
|
||||
hidden_act=args.hidden_act,
|
||||
max_position_embeddings=args.n_positions,
|
||||
dtype=dtype,
|
||||
mlp_hidden_size=args.inter_size,
|
||||
position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
|
||||
mapping=mapping,
|
||||
rotary_base=args.rotary_base,
|
||||
rotary_scaling=args.rotary_scaling,
|
||||
use_parallel_embedding=args.use_parallel_embedding,
|
||||
embedding_sharding_dim=args.embedding_sharding_dim,
|
||||
quant_mode=args.quant_mode,
|
||||
rms_norm_eps=args.rms_norm_eps,
|
||||
gather_all_token_logits=args.gather_all_token_logits)
|
||||
if args.use_smooth_quant:
|
||||
xtrt_llm_llama = smooth_quantize(xtrt_llm_llama, args.quant_mode)
|
||||
elif args.use_weight_only:
|
||||
if args.weight_only_precision == 'int8' or args.weight_only_precision == 'int16':
|
||||
'''
|
||||
xtrt_llm_llama = weight_only_quantize(xtrt_llm_llama,
|
||||
args.quant_mode)
|
||||
'''
|
||||
elif args.weight_only_precision == 'int4':
|
||||
'''
|
||||
xtrt_llm_llama = weight_only_quantize(xtrt_llm_llama,
|
||||
args.quant_mode)
|
||||
'''
|
||||
elif args.weight_only_precision == 'int4_awq':
|
||||
xtrt_llm_llama = weight_only_groupwise_quantize(
|
||||
model=xtrt_llm_llama,
|
||||
quant_mode=args.quant_mode,
|
||||
group_size=args.group_size,
|
||||
zero=False,
|
||||
pre_quant_scale=True,
|
||||
exclude_modules=[])
|
||||
elif args.weight_only_precision == 'int4_gptq':
|
||||
xtrt_llm_llama = weight_only_groupwise_quantize(
|
||||
model=xtrt_llm_llama,
|
||||
quant_mode=args.quant_mode,
|
||||
group_size=args.group_size,
|
||||
zero=True,
|
||||
pre_quant_scale=False)
|
||||
|
||||
if args.per_group:
|
||||
load_func = load_from_awq_llama if args.weight_only_precision == 'int4_awq' else load_from_gptq_llama
|
||||
load_func(xtrt_llm_llama=xtrt_llm_llama,
|
||||
quant_ckpt_path=args.quant_ckpt_path,
|
||||
mapping=mapping,
|
||||
dtype=args.dtype)
|
||||
elif args.meta_ckpt_dir is not None:
|
||||
load_from_meta_llama(xtrt_llm_llama, args.meta_ckpt_dir, mapping,
|
||||
args.dtype)
|
||||
elif args.model_dir is not None:
|
||||
logger.info(f'Loading HF LLaMA ... from {args.model_dir}')
|
||||
tik = time.time()
|
||||
hf_llama = LlamaForCausalLM.from_pretrained(
|
||||
args.model_dir,
|
||||
device_map={
|
||||
"model": "cpu",
|
||||
"lm_head": "cpu"
|
||||
}, # Load to CPU memory
|
||||
torch_dtype="auto")
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'HF LLaMA loaded. Total time: {t}')
|
||||
load_from_hf_llama(xtrt_llm_llama,
|
||||
hf_llama,
|
||||
mapping=mapping,
|
||||
dtype=args.dtype)
|
||||
del hf_llama
|
||||
elif args.ft_model_dir is not None:
|
||||
load_from_binary(xtrt_llm_llama,
|
||||
args.ft_model_dir,
|
||||
mapping,
|
||||
fp16=(args.dtype == 'float16'),
|
||||
multi_query_mode=(args.n_kv_head != args.n_head))
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
if args.use_gpt_attention_plugin:
|
||||
network.plugin_config.set_gpt_attention_plugin(
|
||||
dtype=args.use_gpt_attention_plugin)
|
||||
if args.use_gemm_plugin:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
|
||||
if args.use_rmsnorm_plugin:
|
||||
network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin)
|
||||
|
||||
# Quantization plugins.
|
||||
if args.use_smooth_quant:
|
||||
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_quantize_tensor_plugin()
|
||||
network.plugin_config.set_quantize_per_token_plugin()
|
||||
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
|
||||
if args.enable_context_fmha:
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
if args.enable_context_fmha_fp32_acc:
|
||||
network.plugin_config.set_context_fmha(
|
||||
ContextFMHAType.enabled_with_fp32_acc)
|
||||
if args.use_weight_only:
|
||||
if args.per_group:
|
||||
network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
else:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
if args.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(args.dtype,
|
||||
args.use_custom_all_reduce)
|
||||
if args.remove_input_padding:
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
if args.paged_kv_cache:
|
||||
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
|
||||
if args.quant_mode.is_weight_only():
|
||||
builder_config.trt_builder_config.use_weight_only = args.weight_only_precision
|
||||
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(xtrt_llm_llama.named_parameters())
|
||||
|
||||
# Forward
|
||||
inputs = xtrt_llm_llama.prepare_inputs(args.max_batch_size,
|
||||
args.max_input_len,
|
||||
args.max_output_len, True,
|
||||
args.max_beam_width,
|
||||
args.max_num_tokens)
|
||||
xtrt_llm_llama(*inputs)
|
||||
if args.enable_debug_output:
|
||||
# mark intermediate nodes' outputs
|
||||
for k, v in xtrt_llm_llama.named_network_outputs():
|
||||
v = v.trt_tensor
|
||||
v.name = k
|
||||
network.trt_network.mark_output(v)
|
||||
v.dtype = dtype
|
||||
|
||||
# xtrt_llm.graph_rewriting.optimize(network)
|
||||
|
||||
engine = None
|
||||
|
||||
# Network -> Engine
|
||||
engine = builder.build_engine(network, builder_config, compiler="gr")
|
||||
if rank == 0:
|
||||
config_path = os.path.join(args.output_dir, 'config.json')
|
||||
builder.save_config(builder_config, config_path)
|
||||
|
||||
if args.opt_memory_use:
|
||||
return engine, network
|
||||
return engine
|
||||
|
||||
|
||||
def build(rank, args):
|
||||
# torch.cuda.set_device(rank % args.gpus_per_node)
|
||||
logger.set_level(args.log_level)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
# when doing serializing build, all ranks share one engine
|
||||
builder = Builder()
|
||||
|
||||
cache = None
|
||||
for cur_rank in range(args.world_size):
|
||||
# skip other ranks if parallel_build is enabled
|
||||
if args.parallel_build and cur_rank != rank:
|
||||
continue
|
||||
# NOTE: when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
|
||||
int8_trt_flag = args.quant_mode.has_act_and_weight_quant() or (
|
||||
not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache())
|
||||
builder_config = builder.create_builder_config(
|
||||
name=MODEL_NAME,
|
||||
precision=args.dtype,
|
||||
timing_cache=args.timing_cache if cache is None else cache,
|
||||
tensor_parallel=args.tp_size,
|
||||
pipeline_parallel=args.pp_size,
|
||||
parallel_build=args.parallel_build,
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
num_kv_heads=args.n_kv_head,
|
||||
hidden_size=args.n_embd,
|
||||
vocab_size=args.vocab_size,
|
||||
hidden_act=args.hidden_act,
|
||||
inter_size = args.inter_size,
|
||||
max_position_embeddings=args.n_positions,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
int8=int8_trt_flag,
|
||||
fp8=False,
|
||||
quant_mode=args.quant_mode,
|
||||
strongly_typed=args.strongly_typed,
|
||||
opt_level=args.builder_opt,
|
||||
fusion_pattern_list=["remove_dup_mask"],
|
||||
gather_all_token_logits=args.gather_all_token_logits)
|
||||
guard = xtrt_llm.fusion_patterns.FuseonPatternGuard()
|
||||
print(guard)
|
||||
engine_name = get_engine_name(MODEL_NAME, args.dtype, args.tp_size,
|
||||
args.pp_size, cur_rank)
|
||||
if args.opt_memory_use:
|
||||
engine, network = build_rank_engine(builder, builder_config,
|
||||
engine_name, cur_rank, args)
|
||||
else:
|
||||
engine = build_rank_engine(builder, builder_config, engine_name,
|
||||
cur_rank, args)
|
||||
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
|
||||
|
||||
# if cur_rank == 0:
|
||||
# # Use in-memory timing cache for multiple builder passes.
|
||||
# if not args.parallel_build:
|
||||
# cache = builder_config.trt_builder_config.get_timing_cache()
|
||||
|
||||
serialize_engine(engine, os.path.join(args.output_dir, engine_name))
|
||||
del engine
|
||||
if args.opt_memory_use:
|
||||
network.__del__()
|
||||
|
||||
# if rank == 0:
|
||||
# ok = builder.save_timing_cache(
|
||||
# builder_config, os.path.join(args.output_dir, "model.cache"))
|
||||
# assert ok, "Failed to save timing cache."
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
tik = time.time()
|
||||
if args.parallel_build and args.world_size:
|
||||
logger.warning(
|
||||
f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
|
||||
)
|
||||
mp.spawn(build, nprocs=args.world_size, args=(args, ))
|
||||
else:
|
||||
args.parallel_build = False
|
||||
logger.info('Serially build TensorRT engines.')
|
||||
build(0, args)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Total time of building all {args.world_size} engines: {t}')
|
||||
313
examples/llama/convert.py
Normal file
313
examples/llama/convert.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities for exporting a model to our custom format.
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def save_val(val, dir, key, tp_num=None):
|
||||
suffix = "bin" if tp_num is None else f"{tp_num}.bin"
|
||||
val.tofile(dir / f"model.{key}.{suffix}")
|
||||
|
||||
|
||||
def save_split(split_vals, dir, key, i, factor):
|
||||
for j, val in enumerate(split_vals):
|
||||
save_val(val, dir, key, i * factor + j)
|
||||
|
||||
|
||||
def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
|
||||
"""
|
||||
This function has two purposes:
|
||||
- compute quantized weights, scaled either per-tensor or per-column
|
||||
- compute scaling factors
|
||||
|
||||
Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
|
||||
CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W.
|
||||
CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor.
|
||||
|
||||
Here is the list of what we need (T means per-tensor, C per-column):
|
||||
- scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T)
|
||||
- scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
|
||||
- scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
|
||||
- scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
|
||||
to quant range (int8) (used for CUBLAS) (T, C)
|
||||
|
||||
Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too,
|
||||
but then the model would change depending on the number of GPUs used.
|
||||
|
||||
For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it
|
||||
as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V.
|
||||
For our GEMM implementation to respect this behavior, we use per-column mode and replicate values along columns.
|
||||
"""
|
||||
|
||||
# compute weight scaling factors for fp->int8 and int8->fp
|
||||
if is_qkv and not multi_query_mode:
|
||||
scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max(
|
||||
dim=-1, keepdims=True)[0].cpu().numpy()
|
||||
scale_w_orig_quant_c = 127. / act_range["w"].reshape(3,
|
||||
-1).cpu().numpy()
|
||||
elif is_qkv and multi_query_mode:
|
||||
hidden_dim = weights.shape[0]
|
||||
local_dim = act_range["w"].shape[0]
|
||||
kv_dim = (local_dim - hidden_dim) // 2
|
||||
scale_w_q = act_range["w"][0:hidden_dim]
|
||||
scale_w_k = act_range["w"][hidden_dim:hidden_dim + kv_dim]
|
||||
scale_w_v = act_range["w"][-kv_dim:]
|
||||
|
||||
scale_w_qkv_t = torch.concat([
|
||||
scale_w_q.max(dim=0, keepdim=True)[0],
|
||||
scale_w_k.max(dim=0, keepdim=True)[0],
|
||||
scale_w_v.max(dim=0, keepdim=True)[0]
|
||||
])
|
||||
|
||||
scale_w_orig_quant_t = 127. / scale_w_qkv_t.cpu().numpy()
|
||||
scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy()
|
||||
else:
|
||||
scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy()
|
||||
scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy()
|
||||
scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
|
||||
scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c
|
||||
|
||||
# compute the rest of needed scaling factors
|
||||
scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item())
|
||||
scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item())
|
||||
scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.)
|
||||
scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t *
|
||||
scale_w_orig_quant_t)
|
||||
scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t *
|
||||
scale_w_orig_quant_c)
|
||||
if is_qkv and not multi_query_mode:
|
||||
scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t,
|
||||
scale_w_orig_quant_c.shape)
|
||||
scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t,
|
||||
scale_w_orig_quant_c.shape)
|
||||
if is_qkv and multi_query_mode:
|
||||
scale_q_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[0],
|
||||
scale_w_q.shape)
|
||||
scale_k_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[1],
|
||||
scale_w_k.shape)
|
||||
scale_v_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[2],
|
||||
scale_w_v.shape)
|
||||
scale_y_accum_quant_t = np.concatenate(
|
||||
[scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t])
|
||||
scale_w_quant_orig_t = np.concatenate([
|
||||
np.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape),
|
||||
np.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape),
|
||||
np.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape)
|
||||
])
|
||||
|
||||
to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8)
|
||||
|
||||
if is_qkv and multi_query_mode:
|
||||
scale_w_quant_orig_t_expand = np.ones([weights.shape[-1]])
|
||||
scale_w_quant_orig_t_expand[:hidden_dim] = scale_w_quant_orig_t[0]
|
||||
scale_w_quant_orig_t_expand[hidden_dim:hidden_dim +
|
||||
kv_dim] = scale_w_quant_orig_t[1]
|
||||
scale_w_quant_orig_t_expand[-kv_dim:] = scale_w_quant_orig_t[2]
|
||||
weight_int8 = to_i8(weights * scale_w_quant_orig_t_expand)
|
||||
else:
|
||||
weight_int8 = to_i8(weights * scale_w_orig_quant_t)
|
||||
return {
|
||||
"weight.int8": weight_int8,
|
||||
"weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
|
||||
"scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32),
|
||||
"scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32),
|
||||
"scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32),
|
||||
"scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32),
|
||||
"scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32),
|
||||
"scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32),
|
||||
}
|
||||
|
||||
|
||||
def save_multi_query_mode_qkv_int8(val, dir, base_key, saved_key, factor, rank,
|
||||
local_dim, head_size):
|
||||
q, k, v = np.split(val, [local_dim, local_dim + head_size], axis=-1)
|
||||
q_split = np.split(q, factor, axis=-1)
|
||||
k_split = np.split(k, factor, axis=-1)
|
||||
v_split = np.split(v, factor, axis=-1)
|
||||
split_vals = [
|
||||
np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1)
|
||||
for ii in range(factor)
|
||||
]
|
||||
save_split(split_vals, dir, f"{base_key}.{saved_key}", rank, factor)
|
||||
|
||||
|
||||
def write_int8(vals,
|
||||
dir,
|
||||
base_key,
|
||||
split_dim,
|
||||
i,
|
||||
factor,
|
||||
is_qkv=False,
|
||||
multi_query_mode=False):
|
||||
saved_keys_once = [
|
||||
"scale_x_orig_quant", "scale_w_quant_orig", "scale_y_accum_quant",
|
||||
"scale_y_quant_orig"
|
||||
]
|
||||
|
||||
if is_qkv and multi_query_mode:
|
||||
assert split_dim == -1
|
||||
local_dim = vals["weight.int8"].shape[0]
|
||||
head_size = (vals["weight.int8"].shape[1] - local_dim) // 2
|
||||
|
||||
save_multi_query_mode_qkv_int8(vals["weight.int8"], dir, base_key,
|
||||
"weight.int8", factor, i, local_dim,
|
||||
head_size)
|
||||
save_multi_query_mode_qkv_int8(vals["weight.int8.col"], dir, base_key,
|
||||
"weight.int8.col", factor, i, local_dim,
|
||||
head_size)
|
||||
save_multi_query_mode_qkv_int8(vals["scale_w_quant_orig.col"], dir,
|
||||
base_key, "scale_w_quant_orig.col",
|
||||
factor, i, local_dim, head_size)
|
||||
save_multi_query_mode_qkv_int8(vals["scale_y_accum_quant.col"], dir,
|
||||
base_key, "scale_y_accum_quant.col",
|
||||
factor, i, local_dim, head_size)
|
||||
save_multi_query_mode_qkv_int8(vals["scale_w_quant_orig"], dir,
|
||||
base_key, "scale_w_quant_orig", factor,
|
||||
i, local_dim, head_size)
|
||||
save_multi_query_mode_qkv_int8(vals["scale_y_accum_quant"], dir,
|
||||
base_key, "scale_y_accum_quant", factor,
|
||||
i, local_dim, head_size)
|
||||
saved_keys_once = ["scale_x_orig_quant", "scale_y_quant_orig"]
|
||||
else:
|
||||
save_split(np.split(vals["weight.int8"], factor, axis=split_dim), dir,
|
||||
f"{base_key}.weight.int8", i, factor)
|
||||
save_split(np.split(vals["weight.int8.col"], factor, axis=split_dim),
|
||||
dir, f"{base_key}.weight.int8.col", i, factor)
|
||||
|
||||
if split_dim == -1:
|
||||
save_split(
|
||||
np.split(vals["scale_w_quant_orig.col"], factor,
|
||||
axis=split_dim), dir,
|
||||
f"{base_key}.scale_w_quant_orig.col", i, factor)
|
||||
save_split(
|
||||
np.split(vals["scale_y_accum_quant.col"],
|
||||
factor,
|
||||
axis=split_dim), dir,
|
||||
f"{base_key}.scale_y_accum_quant.col", i, factor)
|
||||
if is_qkv:
|
||||
save_split(
|
||||
np.split(vals["scale_y_accum_quant"],
|
||||
factor,
|
||||
axis=split_dim), dir,
|
||||
f"{base_key}.scale_y_accum_quant", i, factor)
|
||||
save_split(
|
||||
np.split(vals["scale_w_quant_orig"], factor,
|
||||
axis=split_dim), dir,
|
||||
f"{base_key}.scale_w_quant_orig", i, factor)
|
||||
saved_keys_once = ["scale_x_orig_quant", "scale_y_quant_orig"]
|
||||
else:
|
||||
saved_keys_once += [
|
||||
"scale_w_quant_orig.col", "scale_y_accum_quant.col"
|
||||
]
|
||||
|
||||
if i == 0:
|
||||
for save_key in saved_keys_once:
|
||||
save_val(vals[save_key], dir, f"{base_key}.{save_key}")
|
||||
|
||||
|
||||
def str_to_np_dtype(type_str):
|
||||
convert_dict = {
|
||||
"fp32": np.float32,
|
||||
"fp16": np.float16,
|
||||
}
|
||||
dtype = convert_dict.get(type_str)
|
||||
if dtype is None:
|
||||
raise ValueError(f"{type_str} is an invalid storage type")
|
||||
return dtype
|
||||
|
||||
|
||||
def split_and_save_weight(i, saved_dir, factor, key, val, act_range, config):
|
||||
# The split_factor indicates the number of ranks to implement
|
||||
# distributed GEMMs. For Tensor Parallelism, each rank/GPU works
|
||||
# on split_hidden_dim // split_factor channels.
|
||||
|
||||
int8_outputs = config.get("int8_outputs", None)
|
||||
multi_query_mode = config.get("multi_query_mode", False)
|
||||
local_dim = config.get("local_dim", None)
|
||||
|
||||
save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only"
|
||||
|
||||
if "input_layernorm.weight" in key or "input_layernorm.bias" in key or \
|
||||
"attention.dense.bias" in key or "post_layernorm.weight" in key or \
|
||||
"post_attention_layernorm.bias" in key or "mlp.dense_4h_to_h.bias" in key or \
|
||||
"final_layernorm.weight" in key or "final_layernorm.bias" in key:
|
||||
|
||||
# shared weights, only need to convert the weights of rank 0
|
||||
if i == 0:
|
||||
save_val(val, saved_dir, key)
|
||||
|
||||
elif "attention.dense.weight" in key or "mlp.proj.weight" in key:
|
||||
split_dim = 0
|
||||
split_vals = np.split(val, factor, axis=split_dim)
|
||||
save_split(split_vals, saved_dir, key, i, factor)
|
||||
if act_range is not None and int8_outputs == "all":
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val, act_range)
|
||||
write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)
|
||||
|
||||
elif "mlp.fc.weight" in key or "mlp.gate.weight" in key:
|
||||
split_dim = -1
|
||||
split_vals = np.split(val, factor, axis=split_dim)
|
||||
save_split(split_vals, saved_dir, key, i, factor)
|
||||
if act_range is not None and int8_outputs == "all":
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val, act_range)
|
||||
write_int8(vals_i8, saved_dir, base_key, split_dim, i, factor)
|
||||
|
||||
elif "attention.query_key_value.weight" in key:
|
||||
hidden_dim = val.shape[0]
|
||||
if local_dim is None:
|
||||
local_dim = val.shape[-1] // 3
|
||||
if multi_query_mode:
|
||||
head_size = (val.shape[-1] - local_dim) // 2
|
||||
val = val.reshape(hidden_dim, local_dim + 2 * head_size)
|
||||
w_q, w_k, w_v = np.split(val, [local_dim, local_dim + head_size],
|
||||
axis=-1)
|
||||
w_q_split = np.split(w_q, factor, axis=-1)
|
||||
w_k_split = np.split(w_k, factor, axis=-1)
|
||||
w_v_split = np.split(w_v, factor, axis=-1)
|
||||
split_vals = [
|
||||
np.concatenate((w_q_split[ii], w_k_split[ii], w_v_split[ii]),
|
||||
axis=-1) for ii in range(factor)
|
||||
]
|
||||
split_dim = -1
|
||||
else:
|
||||
val = val.reshape(hidden_dim, 3, local_dim)
|
||||
split_dim = -1
|
||||
split_vals = np.split(val, factor, axis=split_dim)
|
||||
save_split(split_vals, saved_dir, key, i, factor)
|
||||
if save_int8:
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val,
|
||||
act_range,
|
||||
is_qkv=True,
|
||||
multi_query_mode=multi_query_mode)
|
||||
write_int8(vals_i8,
|
||||
saved_dir,
|
||||
base_key,
|
||||
split_dim,
|
||||
i,
|
||||
factor,
|
||||
is_qkv=True,
|
||||
multi_query_mode=multi_query_mode)
|
||||
elif "attention.dense.smoother" in key or "mlp.proj.smoother" in key:
|
||||
split_vals = np.split(val, factor, axis=0)
|
||||
save_split(split_vals, saved_dir, key, i, factor)
|
||||
|
||||
else:
|
||||
print(f"[WARNING] {key} not handled by converter")
|
||||
335
examples/llama/hf_llama_convert.py
Normal file
335
examples/llama/hf_llama_convert.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
'''
|
||||
Convert huggingface GPT model. Use https://huggingface.co/gpt2 as demo.
|
||||
'''
|
||||
import argparse
|
||||
import configparser
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as multiprocessing
|
||||
from convert import split_and_save_weight, str_to_np_dtype
|
||||
from smoothquant import (capture_activation_range, smooth_gemm,
|
||||
smooth_gemm_fc1_gate)
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
||||
|
||||
|
||||
def merge_qkv_scales(q_name, hf_model, scales, llama_qkv_para):
|
||||
layer_name_q = q_name.replace(".weight", "")
|
||||
layer_name_k = layer_name_q.replace("q_proj", "k_proj")
|
||||
layer_name_v = layer_name_q.replace("q_proj", "v_proj")
|
||||
layer_name_qkv = layer_name_q.replace("q_proj", "qkv_proj")
|
||||
|
||||
q = hf_model.state_dict()[layer_name_q + ".weight"]
|
||||
k = hf_model.state_dict()[layer_name_k + ".weight"]
|
||||
v = hf_model.state_dict()[layer_name_v + ".weight"]
|
||||
|
||||
weight = torch.cat([q, k, v], dim=0)
|
||||
|
||||
scales[layer_name_qkv]["x"] = scales[layer_name_q]["x"]
|
||||
scales[layer_name_qkv]["w"] = weight.abs().max(dim=1)[0]
|
||||
print(scales[layer_name_q])
|
||||
scales[layer_name_qkv]["y"] = torch.cat([
|
||||
scales[layer_name_q]["y"], scales[layer_name_k]["y"],
|
||||
scales[layer_name_v]["y"]
|
||||
],
|
||||
dim=0)
|
||||
|
||||
llama_qkv_para[layer_name_qkv] = weight.transpose(0, 1)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_llama_model(model, scales, alpha, llama_qkv_para, llama_smoother):
|
||||
# Smooth the activation and weights with smoother = $\diag{s}$
|
||||
for name, module in model.named_modules():
|
||||
if not isinstance(module, LlamaDecoderLayer):
|
||||
continue
|
||||
# qkv_proj
|
||||
layer_name_q = name + ".self_attn.q_proj"
|
||||
layer_name_k = name + ".self_attn.k_proj"
|
||||
layer_name_v = name + ".self_attn.v_proj"
|
||||
layer_name_qkv = name + ".self_attn.qkv_proj"
|
||||
|
||||
weight = torch.cat([
|
||||
module.self_attn.q_proj.weight, module.self_attn.k_proj.weight,
|
||||
module.self_attn.v_proj.weight
|
||||
],
|
||||
dim=0)
|
||||
|
||||
smoother = smooth_gemm(weight, scales[layer_name_q]["x"],
|
||||
module.input_layernorm.weight, None, alpha)
|
||||
|
||||
scales[layer_name_qkv]["x"] = scales[layer_name_q]["x"] / smoother
|
||||
scales[layer_name_qkv]["w"] = weight.abs().max(dim=1)[0]
|
||||
scales[layer_name_qkv]["y"] = torch.cat([
|
||||
scales[layer_name_q]["y"], scales[layer_name_k]["y"],
|
||||
scales[layer_name_v]["y"]
|
||||
],
|
||||
dim=0)
|
||||
|
||||
# see transpose_weights function
|
||||
llama_qkv_para[layer_name_qkv] = weight.transpose(0, 1)
|
||||
|
||||
# =================================================================
|
||||
layer_name = name + ".self_attn.o_proj"
|
||||
smoother = smooth_gemm(module.self_attn.o_proj.weight,
|
||||
scales[layer_name]["x"], None, None, alpha)
|
||||
llama_smoother[layer_name] = smoother.float()
|
||||
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
|
||||
scales[layer_name]["w"] = module.self_attn.o_proj.weight.abs().max(
|
||||
dim=1)[0]
|
||||
|
||||
# ==================================================================
|
||||
fc1_layer_name = name + ".mlp.gate_proj"
|
||||
gate_layer_name = name + ".mlp.up_proj"
|
||||
|
||||
smoother = smooth_gemm_fc1_gate(module.mlp.gate_proj.weight,
|
||||
module.mlp.up_proj.weight,
|
||||
scales[fc1_layer_name]["x"],
|
||||
module.post_attention_layernorm.weight,
|
||||
None, alpha)
|
||||
|
||||
scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother
|
||||
scales[fc1_layer_name]["w"] = module.mlp.gate_proj.weight.abs().max(
|
||||
dim=1)[0]
|
||||
|
||||
scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother
|
||||
scales[gate_layer_name]["w"] = module.mlp.up_proj.weight.abs().max(
|
||||
dim=1)[0]
|
||||
|
||||
# ==================================================================
|
||||
layer_name = name + ".mlp.down_proj"
|
||||
smoother = smooth_gemm(module.mlp.down_proj.weight,
|
||||
scales[layer_name]["x"], None, None, alpha)
|
||||
llama_smoother[layer_name] = smoother.float()
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
|
||||
scales[layer_name]["w"] = module.mlp.down_proj.weight.abs().max(
|
||||
dim=1)[0]
|
||||
|
||||
|
||||
def gpt_to_ft_name(orig_name):
|
||||
global_ft_weights = {
|
||||
"model.embed_tokens.weight": 'vocab_embedding.weight',
|
||||
"model.norm.weight": 'ln_f.weight',
|
||||
"lm_head.weight": 'lm_head.weight',
|
||||
}
|
||||
|
||||
if orig_name in global_ft_weights:
|
||||
return global_ft_weights[orig_name]
|
||||
|
||||
_, _, layer_id, *weight_name = orig_name.split(".")
|
||||
|
||||
layer_id = int(layer_id)
|
||||
weight_name = ".".join(weight_name)
|
||||
|
||||
if weight_name == 'self_attn.q_proj.weight':
|
||||
return f"layers.{layer_id}.attention.query_key_value.weight"
|
||||
elif weight_name == 'self_attn.k_proj.weight' or weight_name == 'self_attn.v_proj.weight':
|
||||
return f"layers.{layer_id}.attention.kv.weight"
|
||||
|
||||
per_layer_weights = {
|
||||
"input_layernorm.weight": "input_layernorm.weight",
|
||||
"self_attn.o_proj.weight": "attention.dense.weight",
|
||||
"mlp.gate_proj.weight": "mlp.fc.weight",
|
||||
"mlp.down_proj.weight": "mlp.proj.weight",
|
||||
"mlp.up_proj.weight": "mlp.gate.weight",
|
||||
"post_attention_layernorm.weight": "post_layernorm.weight",
|
||||
}
|
||||
|
||||
return f"layers.{layer_id}.{per_layer_weights[weight_name]}"
|
||||
|
||||
|
||||
# LLaMA uses nn.Linear for these following ops whose weight matrix is transposed compared to gpt2.
|
||||
# In order to use the preprocess codes of gpt2, we transpose them firstly.
|
||||
def transpose_weights(hf_name, param):
|
||||
weight_to_transpose = ["o_proj", "gate_proj", "down_proj", "up_proj"]
|
||||
if any([k in hf_name for k in weight_to_transpose]):
|
||||
if len(param.shape) == 2:
|
||||
param = param.transpose(0, 1)
|
||||
return param
|
||||
|
||||
|
||||
def hf_gpt_converter(args):
|
||||
infer_tp = args.tensor_parallelism
|
||||
saved_dir = Path(args.out_dir) / f"{infer_tp}-XPU"
|
||||
saved_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(args.in_file, device_map="auto")
|
||||
|
||||
act_range = {}
|
||||
llama_qkv_para = {}
|
||||
# smoother for inputs of self_attn.o_proj and mlp.down_proj
|
||||
llama_smoother = {}
|
||||
|
||||
if args.smoothquant is not None or args.calibrate_kv_cache:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
|
||||
"TOKENIZERS_PARALLELISM", "false")
|
||||
act_range = capture_activation_range(
|
||||
model,
|
||||
LlamaTokenizer.from_pretrained(args.in_file, padding_side='left'))
|
||||
if args.smoothquant is not None:
|
||||
smooth_llama_model(model, act_range, args.smoothquant,
|
||||
llama_qkv_para, llama_smoother)
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config["llama"] = {}
|
||||
for key in vars(args):
|
||||
config["llama"][key] = f"{vars(args)[key]}"
|
||||
for k, v in vars(model.config).items():
|
||||
config["llama"][k] = f"{v}"
|
||||
config["llama"]["weight_data_type"] = args.storage_type
|
||||
config["llama"]["multi_query_mode"] = str(args.multi_query_mode)
|
||||
with open(saved_dir / "config.ini", 'w') as configfile:
|
||||
config.write(configfile)
|
||||
|
||||
storage_type = str_to_np_dtype(args.storage_type)
|
||||
|
||||
global_ft_weights = [
|
||||
'vocab_embedding.weight', 'ln_f.weight', 'lm_head.weight'
|
||||
]
|
||||
|
||||
int8_outputs = None
|
||||
if args.calibrate_kv_cache:
|
||||
int8_outputs = "kv_cache_only"
|
||||
if args.smoothquant is not None:
|
||||
int8_outputs = "all"
|
||||
|
||||
starmap_args = []
|
||||
for name, param in model.named_parameters():
|
||||
if "weight" not in name and "bias" not in name:
|
||||
continue
|
||||
ft_name = gpt_to_ft_name(name)
|
||||
|
||||
if name.replace(".weight", "") in llama_smoother.keys():
|
||||
smoother = llama_smoother[name.replace(".weight", "")]
|
||||
smoother = smoother.detach().cpu().numpy()
|
||||
starmap_args.append(
|
||||
(0, saved_dir, infer_tp,
|
||||
f"{ft_name}.smoother".replace(".weight", ""), smoother, None, {
|
||||
"int8_outputs": int8_outputs,
|
||||
"multi_query_mode": args.multi_query_mode,
|
||||
"local_dim": None,
|
||||
}))
|
||||
|
||||
param = transpose_weights(name, param)
|
||||
|
||||
param = param.detach().cpu().numpy().astype(storage_type)
|
||||
|
||||
if ft_name in global_ft_weights:
|
||||
param.tofile(saved_dir / f"{ft_name}.bin")
|
||||
elif ft_name.split('.')[-2] == 'query_key_value':
|
||||
# Is there other ways to get local_dim? local_dim = hidden_size in llama2
|
||||
local_dim = model.config.hidden_size if args.multi_query_mode else None
|
||||
if args.smoothquant is None:
|
||||
merge_qkv_scales(name, model, act_range, llama_qkv_para)
|
||||
qkv = (0, saved_dir, infer_tp, ft_name,
|
||||
llama_qkv_para.get(
|
||||
name.replace(".weight", "").replace(
|
||||
".q_proj",
|
||||
".qkv_proj")).cpu().numpy().astype(storage_type),
|
||||
act_range.get(
|
||||
name.replace(".weight",
|
||||
"").replace(".q_proj", ".qkv_proj")), {
|
||||
"int8_outputs": int8_outputs,
|
||||
"multi_query_mode":
|
||||
args.multi_query_mode,
|
||||
"local_dim": local_dim,
|
||||
})
|
||||
starmap_args.append(qkv)
|
||||
elif ft_name.split('.')[-2] == 'kv':
|
||||
continue
|
||||
else:
|
||||
starmap_args.append((0, saved_dir, infer_tp, ft_name, param,
|
||||
act_range.get(name.replace(".weight", "")), {
|
||||
"int8_outputs": int8_outputs,
|
||||
"multi_query_mode": args.multi_query_mode,
|
||||
"local_dim": None,
|
||||
}))
|
||||
|
||||
starmap_args = tqdm(starmap_args, desc="saving weights")
|
||||
if args.processes > 1:
|
||||
with multiprocessing.Pool(args.processes) as pool:
|
||||
pool.starmap(split_and_save_weight, starmap_args)
|
||||
else:
|
||||
# simpler for debug situations
|
||||
for starmap_arg in starmap_args:
|
||||
split_and_save_weight(*starmap_arg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
parser.add_argument('--out-dir',
|
||||
'-o',
|
||||
type=str,
|
||||
help='file name of output directory',
|
||||
required=True)
|
||||
parser.add_argument('--in-file',
|
||||
'-i',
|
||||
type=str,
|
||||
help='file name of input checkpoint file',
|
||||
required=True)
|
||||
parser.add_argument('--tensor-parallelism',
|
||||
'-tp',
|
||||
type=int,
|
||||
help='Requested tensor parallelism for inference',
|
||||
default=1)
|
||||
parser.add_argument(
|
||||
"--processes",
|
||||
"-p",
|
||||
type=int,
|
||||
help="How many processes to spawn for conversion (default: 4)",
|
||||
default=4)
|
||||
parser.add_argument(
|
||||
"--calibrate-kv-cache",
|
||||
"-kv",
|
||||
action="store_true",
|
||||
help=
|
||||
"Generate scaling factors for KV cache. Used for storing KV cache in int8."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--smoothquant",
|
||||
"-sq",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
|
||||
" to Smoothquant the model, and output int8 weights."
|
||||
" A good first try is 0.5. Must be in [0, 1]")
|
||||
parser.add_argument("--storage-type",
|
||||
"-t",
|
||||
type=str,
|
||||
default="fp32",
|
||||
choices=["fp32", "fp16"])
|
||||
parser.add_argument("--multi-query-mode",
|
||||
action="store_true",
|
||||
help="Use multi-query-attention.")
|
||||
|
||||
args = parser.parse_args()
|
||||
print("\n=============== Argument ===============")
|
||||
for key in vars(args):
|
||||
print("{}: {}".format(key, vars(args)[key]))
|
||||
print("========================================")
|
||||
|
||||
assert (args.calibrate_kv_cache or args.smoothquant), \
|
||||
"Either INT8 kv cache or SmoothQuant must be enabled for this script. Otherwise you can directly build engines from HuggingFace checkpoints, no need to do this FT-format conversion. "
|
||||
|
||||
hf_gpt_converter(args)
|
||||
135
examples/llama/quantize.py
Normal file
135
examples/llama/quantize.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from examples/quantization/hf_ptq.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from tensorrt_llm._utils import str_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.models.quantized.ammo import quantize_and_export
|
||||
|
||||
|
||||
def get_calib_dataloader(data="cnn_dailymail",
|
||||
tokenizer=None,
|
||||
batch_size=1,
|
||||
calib_size=512,
|
||||
block_size=512):
|
||||
print("Loading calibration dataset")
|
||||
if data == "pileval":
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||
split="train")
|
||||
dataset = dataset["text"][:calib_size]
|
||||
elif data == "cnn_dailymail":
|
||||
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
|
||||
dataset = dataset["article"][:calib_size]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
batch_encoded = tokenizer.batch_encode_plus(dataset,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
max_length=block_size)
|
||||
batch_encoded = batch_encoded["input_ids"]
|
||||
batch_encoded = batch_encoded.cuda()
|
||||
|
||||
calib_dataloader = DataLoader(batch_encoded,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
return calib_dataloader
|
||||
|
||||
|
||||
def get_tokenizer(ckpt_path, **kwargs):
|
||||
logger.info(f"Loading tokenizer from {ckpt_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt_path,
|
||||
padding_side="left",
|
||||
**kwargs)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(ckpt_path, dtype="float16"):
|
||||
logger.info(f"Loading model from {ckpt_path}")
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
ckpt_path,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
model.eval()
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
return model
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--model_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory of a HF model checkpoint")
|
||||
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||
parser.add_argument(
|
||||
"--qformat",
|
||||
type=str,
|
||||
choices=['fp8', 'int4_awq'],
|
||||
default='fp8',
|
||||
help='Quantization format. Currently only fp8 is supported. '
|
||||
'For int8 smoothquant, use smoothquant.py instead. ')
|
||||
parser.add_argument("--calib_size",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Number of samples for calibration.")
|
||||
parser.add_argument("--export_path", default="exported_model")
|
||||
parser.add_argument('--seed', type=int, default=None, help='Random seed')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
if not torch.cuda.is_available():
|
||||
raise EnvironmentError("GPU is required for inference.")
|
||||
|
||||
args = get_args()
|
||||
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
tokenizer = get_tokenizer(args.model_dir)
|
||||
model = get_model(args.model_dir, args.dtype)
|
||||
|
||||
calib_dataloader = get_calib_dataloader(tokenizer=tokenizer,
|
||||
calib_size=args.calib_size)
|
||||
model = quantize_and_export(model,
|
||||
qformat=args.qformat,
|
||||
calib_dataloader=calib_dataloader,
|
||||
export_path=args.export_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3
examples/llama/requirements.txt
Normal file
3
examples/llama/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
datasets==2.14.5
|
||||
rouge_score~=0.1.2
|
||||
sentencepiece~=0.1.99
|
||||
328
examples/llama/run.py
Normal file
328
examples/llama/run.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import LlamaTokenizer
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
from xtrt_llm.runtime import ModelConfig, SamplingConfig
|
||||
|
||||
from build import get_engine_name # isort:skip
|
||||
|
||||
EOS_TOKEN = 2
|
||||
PAD_TOKEN = 2
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def throttle_generator(generator, stream_interval):
|
||||
for i, out in enumerate(generator):
|
||||
if not i % stream_interval:
|
||||
yield out
|
||||
|
||||
if i % stream_interval:
|
||||
yield out
|
||||
|
||||
|
||||
def read_config(config_path: Path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
|
||||
remove_input_padding = config['plugin_config']['remove_input_padding']
|
||||
dtype = config['builder_config']['precision']
|
||||
tp_size = config['builder_config']['tensor_parallel']
|
||||
pp_size = config['builder_config']['pipeline_parallel']
|
||||
world_size = tp_size * pp_size
|
||||
assert world_size == xtrt_llm.mpi_world_size(), \
|
||||
f'Engine world size ({world_size}) != Runtime world size ({xtrt_llm.mpi_world_size()})'
|
||||
num_heads = config['builder_config']['num_heads'] // tp_size
|
||||
hidden_size = config['builder_config']['hidden_size'] // tp_size
|
||||
vocab_size = config['builder_config']['vocab_size']
|
||||
num_layers = config['builder_config']['num_layers']
|
||||
num_kv_heads = config['builder_config'].get('num_kv_heads', num_heads)
|
||||
paged_kv_cache = config['plugin_config']['paged_kv_cache']
|
||||
tokens_per_block = config['plugin_config']['tokens_per_block']
|
||||
quant_mode = QuantMode(config['builder_config']['quant_mode'])
|
||||
gather_all_token_logits = config['builder_config'][
|
||||
'gather_all_token_logits']
|
||||
if config['builder_config'].get('multi_query_mode', False):
|
||||
xtrt_llm.logger.warning(
|
||||
"`multi_query_mode` config is deprecated. Please rebuild the engine."
|
||||
)
|
||||
num_kv_heads = 1
|
||||
num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
|
||||
use_custom_all_reduce = config['plugin_config'].get('use_custom_all_reduce',
|
||||
False)
|
||||
|
||||
model_config = ModelConfig(num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
hidden_size=hidden_size,
|
||||
vocab_size=vocab_size,
|
||||
num_layers=num_layers,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
paged_kv_cache=paged_kv_cache,
|
||||
tokens_per_block=tokens_per_block,
|
||||
remove_input_padding=remove_input_padding,
|
||||
dtype=dtype,
|
||||
quant_mode=quant_mode,
|
||||
use_custom_all_reduce=use_custom_all_reduce,
|
||||
gather_all_token_logits=gather_all_token_logits)
|
||||
|
||||
return model_config, tp_size, pp_size, dtype
|
||||
|
||||
|
||||
def parse_input(input_text: str, input_file: str, tokenizer, end_id: int,
|
||||
remove_input_padding: bool):
|
||||
input_tokens = []
|
||||
if input_file is None:
|
||||
input_tokens.append(
|
||||
tokenizer.encode(input_text, add_special_tokens=False))
|
||||
else:
|
||||
if input_file.endswith('.csv'):
|
||||
with open(input_file, 'r') as csv_file:
|
||||
csv_reader = csv.reader(csv_file, delimiter=',')
|
||||
for line in csv_reader:
|
||||
input_tokens.append(np.array(line, dtype='int32'))
|
||||
elif input_file.endswith('.npy'):
|
||||
inputs = np.load(input_file)
|
||||
for row in inputs:
|
||||
row = row[row != end_id]
|
||||
input_tokens.append(row)
|
||||
elif input_file.endswith('.txt'):
|
||||
with open(input_file, 'r', encoding='utf-8') as file:
|
||||
for line in file.readlines():
|
||||
line = line.strip("\n")
|
||||
input_tokens.append(
|
||||
tokenizer.encode(line, add_special_tokens=False))
|
||||
else:
|
||||
print('Input file format not supported.')
|
||||
raise SystemExit
|
||||
|
||||
input_ids = None
|
||||
input_lengths = torch.tensor([len(x) for x in input_tokens],
|
||||
dtype=torch.int32).cuda()
|
||||
if remove_input_padding:
|
||||
input_ids = np.concatenate(input_tokens)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int32,
|
||||
device='cuda').unsqueeze(0)
|
||||
else:
|
||||
input_ids = torch.nested.to_padded_tensor(
|
||||
torch.nested.nested_tensor(input_tokens, dtype=torch.int32),
|
||||
end_id).cuda()
|
||||
print(input_ids)
|
||||
|
||||
return input_ids, input_lengths
|
||||
|
||||
|
||||
def print_output(output_ids, input_lengths, max_output_len, tokenizer,
|
||||
output_csv, output_npy, remove_input_padding):
|
||||
num_beams = output_ids.size(1)
|
||||
if output_csv is None and output_npy is None:
|
||||
for b in range(input_lengths.size(0)):
|
||||
inputs = output_ids[b][0][:input_lengths[b]].tolist()
|
||||
input_text = tokenizer.decode(inputs)
|
||||
print(f'Input: \"{input_text}\"')
|
||||
for beam in range(num_beams):
|
||||
output_begin = max(input_lengths)
|
||||
output_end = output_begin + max_output_len
|
||||
outputs = output_ids[b][beam][output_begin:output_end].tolist()
|
||||
output_text = tokenizer.decode(outputs)
|
||||
print(f'Output: \"{output_text}\"')
|
||||
|
||||
output_ids = output_ids.reshape((-1, output_ids.size(2)))
|
||||
|
||||
if output_csv is not None:
|
||||
output_file = Path(output_csv)
|
||||
output_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
outputs = output_ids.tolist()
|
||||
with open(output_file, 'w') as csv_file:
|
||||
writer = csv.writer(csv_file, delimiter=',')
|
||||
writer.writerows(outputs)
|
||||
|
||||
if output_npy is not None:
|
||||
output_file = Path(output_npy)
|
||||
output_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
outputs = np.array(output_ids.cpu().contiguous(), dtype='int32')
|
||||
np.save(output_file, outputs)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--max_output_len', type=int, required=True)
|
||||
parser.add_argument('--log_level', type=str, default='error')
|
||||
parser.add_argument('--engine_dir', type=str, default='llama_outputs')
|
||||
parser.add_argument('--tokenizer_dir',
|
||||
type=str,
|
||||
default=".",
|
||||
help="Directory containing the tokenizer.model.")
|
||||
parser.add_argument('--input_text',
|
||||
type=str,
|
||||
default='Born in north-east France, Soyer trained as a')
|
||||
parser.add_argument(
|
||||
'--input_tokens',
|
||||
dest='input_file',
|
||||
type=str,
|
||||
help=
|
||||
'CSV or Numpy file containing tokenized input. Alternative to text input.',
|
||||
default=None)
|
||||
parser.add_argument('--output_csv',
|
||||
type=str,
|
||||
help='CSV file where the tokenized output is stored.',
|
||||
default=None)
|
||||
parser.add_argument('--output_npy',
|
||||
type=str,
|
||||
help='Numpy file where the tokenized output is stored.',
|
||||
default=None)
|
||||
parser.add_argument('--num_beams',
|
||||
type=int,
|
||||
help="Use beam search if num_beams >1",
|
||||
default=1)
|
||||
parser.add_argument('--streaming', default=False, action='store_true')
|
||||
parser.add_argument('--streaming_interval',
|
||||
type=int,
|
||||
help="How often to return tokens when streaming.",
|
||||
default=5)
|
||||
parser.add_argument(
|
||||
'--performance_test_scale',
|
||||
type=str,
|
||||
help=
|
||||
"Scale for performance test. e.g., 8x1024x64 (batch_size, input_text_length, max_output_length)",
|
||||
default="")
|
||||
parser.add_argument('--not_warmup', default=False, action='store_true')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def generate(
|
||||
max_output_len: int,
|
||||
log_level: str = 'error',
|
||||
engine_dir: str = 'llama_outputs',
|
||||
input_text: str = 'Born in north-east France, Soyer trained as a',
|
||||
input_file: str = None,
|
||||
output_csv: str = None,
|
||||
output_npy: str = None,
|
||||
tokenizer_dir: str = None,
|
||||
num_beams: int = 1,
|
||||
streaming: bool = False,
|
||||
streaming_interval: int = 5,
|
||||
performance_test_scale: str = "",
|
||||
not_warmup: bool = False,
|
||||
):
|
||||
xtrt_llm.logger.set_level(log_level)
|
||||
|
||||
engine_dir = Path(engine_dir)
|
||||
config_path = engine_dir / 'config.json'
|
||||
model_config, tp_size, pp_size, dtype = read_config(config_path)
|
||||
world_size = tp_size * pp_size
|
||||
|
||||
runtime_rank = xtrt_llm.mpi_rank()
|
||||
if world_size > 1:
|
||||
os.environ["XCCL_GROUP_ID"] = str(runtime_rank // world_size)
|
||||
os.environ["XCCL_NRANKS"] = str(world_size)
|
||||
os.environ["XCCL_CUR_RANK"] = str(runtime_rank % world_size)
|
||||
os.environ["XCCL_DEVICE_ID"] = str(runtime_rank)
|
||||
os.environ["MP_RUN"] = str(1)
|
||||
# if runtime_rank == 0:
|
||||
# os.environ["XTCL_PRINT_L3_PLAN"] = "3"
|
||||
runtime_mapping = xtrt_llm.Mapping(world_size,
|
||||
runtime_rank,
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_dir, legacy=False)
|
||||
|
||||
sampling_config = SamplingConfig(end_id=EOS_TOKEN,
|
||||
pad_id=PAD_TOKEN,
|
||||
num_beams=num_beams)
|
||||
|
||||
engine_name = get_engine_name('llama', dtype, tp_size, pp_size,
|
||||
runtime_rank)
|
||||
serialize_path = str(engine_dir) + "/" + engine_name
|
||||
# with open(serialize_path, 'rb') as f:
|
||||
# engine_buffer = f.read()
|
||||
decoder = xtrt_llm.runtime.GenerationSession(model_config,
|
||||
serialize_path,
|
||||
runtime_mapping,
|
||||
debug_mode=False,
|
||||
debug_tensors_to_save=None)
|
||||
if runtime_rank == 0:
|
||||
print(f"Running the {dtype} engine ...")
|
||||
|
||||
input_ids, input_lengths = parse_input(input_text, input_file, tokenizer,
|
||||
EOS_TOKEN,
|
||||
model_config.remove_input_padding)
|
||||
|
||||
if performance_test_scale != "":
|
||||
performance_test_scale_list = performance_test_scale.split("E")
|
||||
for scale in performance_test_scale_list:
|
||||
xtrt_llm.logger.info(f"Running performance test with scale {scale}")
|
||||
bs, seqlen, _max_output_len = [int(x) for x in scale.split("x")]
|
||||
_input_ids = torch.from_numpy(
|
||||
np.zeros((bs, seqlen)).astype("int32")).cuda()
|
||||
_input_lengths = torch.from_numpy(
|
||||
np.full((bs, ), seqlen).astype("int32")).cuda()
|
||||
_max_input_length = torch.max(_input_lengths).item()
|
||||
if model_config.remove_input_padding:
|
||||
_input_ids = _input_ids.view((1, -1)).contiguous()
|
||||
import time
|
||||
_t_begin = time.time()
|
||||
decoder.setup(_input_lengths.size(0), _max_input_length,
|
||||
_max_output_len, num_beams)
|
||||
_output_gen_ids = decoder.decode(_input_ids,
|
||||
_input_lengths,
|
||||
sampling_config,
|
||||
streaming=streaming)
|
||||
_t_end = time.time()
|
||||
xtrt_llm.logger.info(
|
||||
f"Total latency: {(_t_end - _t_begin) * 1000:.3f} ms")
|
||||
xtrt_llm.logger.info(
|
||||
f"Throughput: {bs * _max_output_len / (_t_end - _t_begin):.3f} tokens/sec"
|
||||
)
|
||||
exit(0)
|
||||
|
||||
max_input_length = torch.max(input_lengths).item()
|
||||
|
||||
decoder.setup(input_lengths.size(0), max_input_length, max_output_len,
|
||||
num_beams)
|
||||
output_gen_ids = decoder.decode(input_ids,
|
||||
input_lengths,
|
||||
sampling_config,
|
||||
streaming=streaming,
|
||||
stop_words_list=[EOS_TOKEN])
|
||||
torch.cuda.synchronize()
|
||||
if streaming:
|
||||
for output_ids in throttle_generator(output_gen_ids,
|
||||
streaming_interval):
|
||||
if runtime_rank == 0:
|
||||
print_output(output_ids, input_lengths, max_output_len,
|
||||
tokenizer, output_csv, output_npy,
|
||||
model_config.remove_input_padding)
|
||||
else:
|
||||
output_ids = output_gen_ids
|
||||
if runtime_rank == 0:
|
||||
print_output(output_ids, input_lengths, max_output_len, tokenizer,
|
||||
output_csv, output_npy,
|
||||
model_config.remove_input_padding)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
generate(**vars(args))
|
||||
20
examples/llama/run.sh
Normal file
20
examples/llama/run.sh
Normal file
@@ -0,0 +1,20 @@
|
||||
SCALE=""
|
||||
for _b in {1..8}; do
|
||||
for _len in {64..1024..32}; do
|
||||
SCALE+="${_b}x${_len}x${_len}E"
|
||||
done
|
||||
done
|
||||
for i in {8..1}; do
|
||||
SCALE+="${i}x2000x64E"
|
||||
done
|
||||
SCALE+="1x2000x64"
|
||||
|
||||
PYTORCH_NO_XPU_MEMORY_CACHING=1 XMLIR_D_XPU_L3_SIZE=0 \
|
||||
python3 run.py \
|
||||
--engine_dir=/root/.cache/llama_outputs/ \
|
||||
--max_output_len 256 \
|
||||
--performance_test_scale 1x2000x64E2x2000x64E4x2000x64E8x2000x64E11x2000x64E1x2000x64E2x2000x64E4x2000x64E8x2000x64E11x2000x64 \
|
||||
--tokenizer_dir=/root/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16/ \
|
||||
--log_level=info
|
||||
|
||||
#_remove_padding
|
||||
205
examples/llama/smoothquant.py
Normal file
205
examples/llama/smoothquant.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
'''
|
||||
Utilities for SmoothQuant models
|
||||
'''
|
||||
|
||||
import copy
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_smoothing(scales,
|
||||
gemm_weights,
|
||||
layernorm_weights=None,
|
||||
layernorm_bias=None,
|
||||
dtype=torch.float32,
|
||||
layernorm_1p=False):
|
||||
if not isinstance(gemm_weights, list):
|
||||
gemm_weights = [gemm_weights]
|
||||
|
||||
if layernorm_weights is not None:
|
||||
assert layernorm_weights.numel() == scales.numel()
|
||||
layernorm_weights.div_(scales).to(dtype)
|
||||
if layernorm_bias is not None:
|
||||
assert layernorm_bias.numel() == scales.numel()
|
||||
layernorm_bias.div_(scales).to(dtype)
|
||||
if layernorm_1p:
|
||||
layernorm_weights += (1 / scales) - 1
|
||||
|
||||
for gemm in gemm_weights:
|
||||
gemm.mul_(scales.view(1, -1)).to(dtype)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_gemm(gemm_weights,
|
||||
act_scales,
|
||||
layernorm_weights=None,
|
||||
layernorm_bias=None,
|
||||
alpha=0.5,
|
||||
weight_scales=None):
|
||||
if not isinstance(gemm_weights, list):
|
||||
gemm_weights = [gemm_weights]
|
||||
orig_dtype = gemm_weights[0].dtype
|
||||
|
||||
for gemm in gemm_weights:
|
||||
# gemm_weights are expected to be transposed
|
||||
assert gemm.shape[1] == act_scales.numel()
|
||||
|
||||
if weight_scales is None:
|
||||
weight_scales = torch.cat(
|
||||
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
|
||||
dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0]
|
||||
weight_scales.to(float).clamp(min=1e-5)
|
||||
scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5)
|
||||
|
||||
apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias,
|
||||
orig_dtype)
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_gemm_fc1_gate(fc1_weights,
|
||||
gate_weights,
|
||||
act_scales,
|
||||
layernorm_weights=None,
|
||||
layernorm_bias=None,
|
||||
alpha=0.5,
|
||||
weight_scales=None):
|
||||
gemm_weights = []
|
||||
if not isinstance(fc1_weights, list):
|
||||
fc1_weights = [fc1_weights]
|
||||
if not isinstance(gate_weights, list):
|
||||
gate_weights = [gate_weights]
|
||||
|
||||
for i in range(len(fc1_weights)):
|
||||
gemm_weight = torch.cat([fc1_weights[i], gate_weights[i]], dim=0)
|
||||
gemm_weights.append(gemm_weight)
|
||||
|
||||
orig_dtype = gemm_weights[0].dtype
|
||||
|
||||
for gemm in gemm_weights:
|
||||
# gemm_weights are expected to be transposed
|
||||
assert gemm.shape[1] == act_scales.numel()
|
||||
|
||||
if weight_scales is None:
|
||||
weight_scales = torch.cat(
|
||||
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
|
||||
dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0]
|
||||
weight_scales.to(float).clamp(min=1e-5)
|
||||
scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5)
|
||||
|
||||
apply_smoothing(scales, fc1_weights + gate_weights, layernorm_weights,
|
||||
layernorm_bias, orig_dtype)
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
|
||||
if not isinstance(fcs, list):
|
||||
fcs = [fcs]
|
||||
for fc in fcs:
|
||||
assert isinstance(fc, nn.Linear)
|
||||
assert ln.weight.numel() == fc.in_features == act_scales.numel()
|
||||
|
||||
device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
|
||||
act_scales = act_scales.to(device=device, dtype=dtype)
|
||||
weight_scales = torch.cat(
|
||||
[fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
|
||||
|
||||
scales = (act_scales.pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
|
||||
|
||||
if ln is not None:
|
||||
ln.weight.div_(scales)
|
||||
ln.bias.div_(scales)
|
||||
|
||||
for fc in fcs:
|
||||
fc.weight.mul_(scales.view(1, -1))
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def capture_activation_range(model, tokenizer, num_samples=512, seq_len=512):
|
||||
model.eval()
|
||||
next(model.parameters()).device
|
||||
act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
|
||||
|
||||
test_token_num = 923
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def stat_tensor(name, tensor, act_scales, key):
|
||||
hidden_dim = tensor.shape[-1]
|
||||
tensor = tensor.view(-1, hidden_dim).abs().detach()
|
||||
comming_max = torch.max(tensor, dim=0)[0].float()
|
||||
|
||||
if act_scales[name][key] is None:
|
||||
act_scales[name][key] = comming_max
|
||||
else:
|
||||
act_scales[name][key] = torch.max(act_scales[name][key],
|
||||
comming_max)
|
||||
|
||||
def stat_input_hook(m, x, y, name):
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
stat_tensor(name, x, act_scales, "x")
|
||||
stat_tensor(name, y, act_scales, "y")
|
||||
|
||||
if act_scales[name]["w"] is None:
|
||||
act_scales[name]["w"] = m.weight.abs().clip(1e-8,
|
||||
None).max(dim=1)[0]
|
||||
|
||||
hooks = []
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
|
||||
hooks.append(
|
||||
m.register_forward_hook(
|
||||
functools.partial(stat_input_hook, name=name)))
|
||||
|
||||
from datasets import load_dataset
|
||||
dataset_cnn = load_dataset("ccdv/cnn_dailymail", '3.0.0')
|
||||
|
||||
for i in tqdm(range(num_samples), desc="calibrating model"):
|
||||
datapoint = dataset_cnn['train'][i:i + 1]
|
||||
line = copy.copy(datapoint['article'])
|
||||
line[0] = line[0] + ' TL;DR: '
|
||||
line[0] = line[0].strip()
|
||||
line[0] = line[0].replace(" n't", "n't")
|
||||
line_encoded = tokenizer(line,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True)["input_ids"].type(torch.int64)
|
||||
line_encoded = line_encoded[:, -test_token_num:]
|
||||
if torch.cuda.is_available():
|
||||
line_encoded = line_encoded.cuda()
|
||||
model(line_encoded)
|
||||
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
|
||||
return act_scales
|
||||
411
examples/llama/summarize.py
Normal file
411
examples/llama/summarize.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizer
|
||||
|
||||
import xtrt_llm
|
||||
import xtrt_llm.profiler as profiler
|
||||
from xtrt_llm.logger import logger
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
from build import get_engine_name # isort:skip
|
||||
|
||||
|
||||
def TRTLLaMA(args, config):
|
||||
dtype = config['builder_config']['precision']
|
||||
tp_size = config['builder_config']['tensor_parallel']
|
||||
pp_size = config['builder_config']['pipeline_parallel']
|
||||
world_size = tp_size * pp_size
|
||||
|
||||
assert world_size == xtrt_llm.mpi_world_size(), \
|
||||
f'Engine world size ({world_size}) != Runtime world size ({xtrt_llm.mpi_world_size()})'
|
||||
|
||||
num_heads = config['builder_config']['num_heads'] // tp_size
|
||||
hidden_size = config['builder_config']['hidden_size'] // tp_size
|
||||
vocab_size = config['builder_config']['vocab_size']
|
||||
num_layers = config['builder_config']['num_layers']
|
||||
use_gpt_attention_plugin = bool(
|
||||
config['plugin_config']['gpt_attention_plugin'])
|
||||
remove_input_padding = config['plugin_config']['remove_input_padding']
|
||||
num_kv_heads = config['builder_config'].get('num_kv_heads', num_heads)
|
||||
builder_config = config['builder_config']
|
||||
gather_all_token_logits = builder_config.get('gather_all_token_logits',
|
||||
False)
|
||||
paged_kv_cache = config['plugin_config']['paged_kv_cache']
|
||||
tokens_per_block = config['plugin_config']['tokens_per_block']
|
||||
use_custom_all_reduce = config['plugin_config'].get('use_custom_all_reduce',
|
||||
False)
|
||||
|
||||
quant_mode = QuantMode(config['builder_config']['quant_mode'])
|
||||
if config['builder_config'].get('multi_query_mode', False):
|
||||
xtrt_llm.logger.warning(
|
||||
"`multi_query_mode` config is deprecated. Please rebuild the engine."
|
||||
)
|
||||
num_kv_heads = 1
|
||||
num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
|
||||
|
||||
model_config = xtrt_llm.runtime.ModelConfig(
|
||||
vocab_size=vocab_size,
|
||||
num_layers=num_layers,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
hidden_size=hidden_size,
|
||||
paged_kv_cache=paged_kv_cache,
|
||||
tokens_per_block=tokens_per_block,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
remove_input_padding=remove_input_padding,
|
||||
use_custom_all_reduce=use_custom_all_reduce,
|
||||
dtype=dtype,
|
||||
quant_mode=quant_mode,
|
||||
gather_all_token_logits=gather_all_token_logits)
|
||||
|
||||
runtime_rank = xtrt_llm.mpi_rank()
|
||||
runtime_mapping = xtrt_llm.Mapping(world_size,
|
||||
runtime_rank,
|
||||
tp_size=tp_size,
|
||||
pp_size=pp_size)
|
||||
if world_size > 1:
|
||||
os.environ["XCCL_GROUP_ID"] = str(runtime_rank // world_size)
|
||||
os.environ["XCCL_NRANKS"] = str(world_size)
|
||||
os.environ["XCCL_CUR_RANK"] = str(runtime_rank % world_size)
|
||||
os.environ["XCCL_DEVICE_ID"] = str(runtime_rank)
|
||||
os.environ["MP_RUN"] = str(1)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
|
||||
engine_name = get_engine_name('llama', dtype, tp_size, pp_size,
|
||||
runtime_rank)
|
||||
serialize_path = str(os.path.join(args.engine_dir, engine_name))
|
||||
|
||||
xtrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
profiler.start('load xtrt_llm engine')
|
||||
# with open(serialize_path, 'rb') as f:
|
||||
# engine_buffer = f.read()
|
||||
decoder = xtrt_llm.runtime.GenerationSession(model_config, serialize_path,
|
||||
runtime_mapping)
|
||||
profiler.stop('load xtrt_llm engine')
|
||||
xtrt_llm.logger.info(
|
||||
f'Load engine takes: {profiler.elapsed_time_in_sec("load xtrt_llm engine")} sec'
|
||||
)
|
||||
return decoder
|
||||
|
||||
|
||||
def main(args):
|
||||
runtime_rank = xtrt_llm.mpi_rank()
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
test_hf = args.test_hf and runtime_rank == 0 # only run hf on rank 0
|
||||
test_trt_llm = args.test_trt_llm
|
||||
hf_model_location = args.hf_model_location
|
||||
profiler.start('load tokenizer')
|
||||
tokenizer = LlamaTokenizer.from_pretrained(hf_model_location,
|
||||
legacy=False,
|
||||
padding_side='left')
|
||||
profiler.stop('load tokenizer')
|
||||
xtrt_llm.logger.info(
|
||||
f'Load tokenizer takes: {profiler.elapsed_time_in_sec("load tokenizer")} sec'
|
||||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset_cnn = load_dataset("ccdv/cnn_dailymail",
|
||||
'3.0.0',
|
||||
cache_dir=args.dataset_path)
|
||||
|
||||
max_batch_size = args.batch_size
|
||||
|
||||
# runtime parameters
|
||||
# repetition_penalty = 1
|
||||
top_k = args.top_k
|
||||
output_len = 100
|
||||
test_token_num = 923
|
||||
# top_p = 0.0
|
||||
# random_seed = 5
|
||||
temperature = 1
|
||||
num_beams = args.num_beams
|
||||
|
||||
pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
|
||||
end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0]
|
||||
|
||||
if test_trt_llm:
|
||||
config_path = os.path.join(args.engine_dir, 'config.json')
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
xtrt_llm_llama = TRTLLaMA(args, config)
|
||||
|
||||
if test_hf:
|
||||
profiler.start('load HF model')
|
||||
model = AutoModelForCausalLM.from_pretrained(hf_model_location)
|
||||
profiler.stop('load HF model')
|
||||
xtrt_llm.logger.info(
|
||||
f'Load HF model takes: {profiler.elapsed_time_in_sec("load HF model")} sec'
|
||||
)
|
||||
if args.data_type == 'fp16':
|
||||
model.half()
|
||||
model.cuda()
|
||||
|
||||
def summarize_xtrt_llm(datapoint):
|
||||
batch_size = len(datapoint['article'])
|
||||
|
||||
line = copy.copy(datapoint['article'])
|
||||
line_encoded = []
|
||||
input_lengths = []
|
||||
for i in range(batch_size):
|
||||
line[i] = line[i] + ' TL;DR: '
|
||||
|
||||
line[i] = line[i].strip()
|
||||
line[i] = line[i].replace(" n't", "n't")
|
||||
|
||||
input_id = tokenizer.encode(line[i],
|
||||
return_tensors='pt').type(torch.int32)
|
||||
input_id = input_id[:, -test_token_num:]
|
||||
|
||||
line_encoded.append(input_id)
|
||||
input_lengths.append(input_id.shape[-1])
|
||||
|
||||
# do padding, should move outside the profiling to prevent the overhead
|
||||
max_length = max(input_lengths)
|
||||
if xtrt_llm_llama.remove_input_padding:
|
||||
line_encoded = [
|
||||
torch.tensor(t, dtype=torch.int32).cuda() for t in line_encoded
|
||||
]
|
||||
else:
|
||||
# do padding, should move outside the profiling to prevent the overhead
|
||||
for i in range(batch_size):
|
||||
pad_size = max_length - input_lengths[i]
|
||||
|
||||
pad = torch.ones([1, pad_size]).type(torch.int32) * pad_id
|
||||
line_encoded[i] = torch.cat(
|
||||
[torch.tensor(line_encoded[i], dtype=torch.int32), pad],
|
||||
axis=-1)
|
||||
|
||||
line_encoded = torch.cat(line_encoded, axis=0).cuda()
|
||||
input_lengths = torch.tensor(input_lengths,
|
||||
dtype=torch.int32).cuda()
|
||||
|
||||
sampling_config = xtrt_llm.runtime.SamplingConfig(end_id=end_id,
|
||||
pad_id=pad_id,
|
||||
top_k=top_k,
|
||||
num_beams=num_beams)
|
||||
|
||||
with torch.no_grad():
|
||||
xtrt_llm_llama.setup(batch_size,
|
||||
max_context_length=max_length,
|
||||
max_new_tokens=output_len,
|
||||
beam_width=num_beams)
|
||||
|
||||
if xtrt_llm_llama.remove_input_padding:
|
||||
output_ids = xtrt_llm_llama.decode_batch(
|
||||
line_encoded, sampling_config)
|
||||
else:
|
||||
output_ids = xtrt_llm_llama.decode(
|
||||
line_encoded,
|
||||
input_lengths,
|
||||
sampling_config,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Extract a list of tensors of shape beam_width x output_ids.
|
||||
if xtrt_llm_llama.mapping.is_first_pp_rank():
|
||||
output_beams_list = [
|
||||
tokenizer.batch_decode(output_ids[batch_idx, :,
|
||||
input_lengths[batch_idx]:],
|
||||
skip_special_tokens=True)
|
||||
for batch_idx in range(batch_size)
|
||||
]
|
||||
return output_beams_list, output_ids[:, :, max_length:].tolist()
|
||||
return [], []
|
||||
|
||||
def summarize_hf(datapoint):
|
||||
batch_size = len(datapoint['article'])
|
||||
if batch_size > 1:
|
||||
logger.warning(
|
||||
f"HF does not support batch_size > 1 to verify correctness due to padding. Current batch size is {batch_size}"
|
||||
)
|
||||
|
||||
line = copy.copy(datapoint['article'])
|
||||
for i in range(batch_size):
|
||||
line[i] = line[i] + ' TL;DR: '
|
||||
|
||||
line[i] = line[i].strip()
|
||||
line[i] = line[i].replace(" n't", "n't")
|
||||
|
||||
line_encoded = tokenizer(line,
|
||||
return_tensors='pt',
|
||||
padding=True,
|
||||
truncation=True)["input_ids"].type(torch.int64)
|
||||
|
||||
line_encoded = line_encoded[:, -test_token_num:]
|
||||
line_encoded = line_encoded.cuda()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(line_encoded,
|
||||
max_length=len(line_encoded[0]) +
|
||||
output_len,
|
||||
top_k=top_k,
|
||||
temperature=temperature,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
num_beams=num_beams,
|
||||
num_return_sequences=num_beams,
|
||||
early_stopping=True)
|
||||
|
||||
tokens_list = output[:, len(line_encoded[0]):].tolist()
|
||||
output = output.reshape([batch_size, num_beams, -1])
|
||||
output_lines_list = [
|
||||
tokenizer.batch_decode(output[:, i, len(line_encoded[0]):],
|
||||
skip_special_tokens=True)
|
||||
for i in range(num_beams)
|
||||
]
|
||||
|
||||
return output_lines_list, tokens_list
|
||||
|
||||
if test_trt_llm:
|
||||
datapoint = dataset_cnn['test'][0:1]
|
||||
summary, _ = summarize_xtrt_llm(datapoint)
|
||||
if runtime_rank == 0:
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
logger.info("XTRT-LLM Generated : ")
|
||||
logger.info(f" Article : {datapoint['article']}")
|
||||
logger.info(f"\n Highlights : {datapoint['highlights']}")
|
||||
logger.info(f"\n Summary : {summary}")
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
|
||||
if test_hf:
|
||||
datapoint = dataset_cnn['test'][0:1]
|
||||
summary, _ = summarize_hf(datapoint)
|
||||
logger.info("---------------------------------------------------------")
|
||||
logger.info("HF Generated : ")
|
||||
logger.info(f" Article : {datapoint['article']}")
|
||||
logger.info(f"\n Highlights : {datapoint['highlights']}")
|
||||
logger.info(f"\n Summary : {summary}")
|
||||
logger.info("---------------------------------------------------------")
|
||||
|
||||
metric_xtrt_llm = [load_metric("rouge") for _ in range(num_beams)]
|
||||
metric_hf = [load_metric("rouge") for _ in range(num_beams)]
|
||||
for i in range(num_beams):
|
||||
metric_xtrt_llm[i].seed = 0
|
||||
metric_hf[i].seed = 0
|
||||
|
||||
ite_count = 0
|
||||
data_point_idx = 0
|
||||
while (data_point_idx < len(dataset_cnn['test'])) and (ite_count <
|
||||
args.max_ite):
|
||||
if runtime_rank == 0:
|
||||
logger.debug(
|
||||
f"run data_point {data_point_idx} ~ {data_point_idx + max_batch_size}"
|
||||
)
|
||||
datapoint = dataset_cnn['test'][data_point_idx:(data_point_idx +
|
||||
max_batch_size)]
|
||||
|
||||
if test_trt_llm:
|
||||
profiler.start('xtrt_llm')
|
||||
summary_xtrt_llm, tokens_xtrt_llm = summarize_xtrt_llm(datapoint)
|
||||
profiler.stop('xtrt_llm')
|
||||
|
||||
if test_hf:
|
||||
profiler.start('hf')
|
||||
summary_hf, tokens_hf = summarize_hf(datapoint)
|
||||
profiler.stop('hf')
|
||||
|
||||
if runtime_rank == 0:
|
||||
if test_trt_llm:
|
||||
for batch_idx in range(len(summary_xtrt_llm)):
|
||||
for beam_idx in range(num_beams):
|
||||
metric_xtrt_llm[beam_idx].add_batch(
|
||||
predictions=[summary_xtrt_llm[batch_idx][beam_idx]],
|
||||
references=[datapoint['highlights'][batch_idx]])
|
||||
if test_hf:
|
||||
for beam_idx in range(num_beams):
|
||||
for batch_idx in range(len(summary_hf[beam_idx])):
|
||||
metric_hf[beam_idx].add_batch(
|
||||
predictions=[summary_hf[beam_idx][batch_idx]],
|
||||
references=[datapoint['highlights'][batch_idx]])
|
||||
|
||||
logger.debug('-' * 100)
|
||||
logger.debug(f"Article : {datapoint['article']}")
|
||||
if test_trt_llm:
|
||||
logger.debug(f'XTRT-LLM Summary: {summary_xtrt_llm}')
|
||||
if test_hf:
|
||||
logger.debug(f'HF Summary: {summary_hf}')
|
||||
logger.debug(f"highlights : {datapoint['highlights']}")
|
||||
|
||||
data_point_idx += max_batch_size
|
||||
ite_count += 1
|
||||
|
||||
if runtime_rank == 0:
|
||||
if test_trt_llm:
|
||||
np.random.seed(0) # rouge score use sampling to compute the score
|
||||
logger.info(
|
||||
f'XTRT-LLM (total latency: {profiler.elapsed_time_in_sec("xtrt_llm")} sec)'
|
||||
)
|
||||
for beam_idx in range(num_beams):
|
||||
logger.info(f"XTRT-LLM beam {beam_idx} result")
|
||||
computed_metrics_xtrt_llm = metric_xtrt_llm[beam_idx].compute()
|
||||
for key in computed_metrics_xtrt_llm.keys():
|
||||
logger.info(
|
||||
f' {key} : {computed_metrics_xtrt_llm[key].mid[2]*100}'
|
||||
)
|
||||
|
||||
if args.check_accuracy and beam_idx == 0:
|
||||
assert computed_metrics_xtrt_llm['rouge1'].mid[
|
||||
2] * 100 > args.xtrt_llm_rouge1_threshold
|
||||
if test_hf:
|
||||
np.random.seed(0) # rouge score use sampling to compute the score
|
||||
logger.info(
|
||||
f'Hugging Face (total latency: {profiler.elapsed_time_in_sec("hf")} sec)'
|
||||
)
|
||||
for beam_idx in range(num_beams):
|
||||
logger.info(f"HF beam {beam_idx} result")
|
||||
computed_metrics_hf = metric_hf[beam_idx].compute()
|
||||
for key in computed_metrics_hf.keys():
|
||||
logger.info(
|
||||
f' {key} : {computed_metrics_hf[key].mid[2]*100}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf_model_location',
|
||||
type=str,
|
||||
default='/workspace/models/llama-models/llama-7b-hf')
|
||||
parser.add_argument('--test_hf', action='store_true')
|
||||
parser.add_argument('--test_trt_llm', action='store_true')
|
||||
parser.add_argument('--data_type',
|
||||
type=str,
|
||||
choices=['fp32', 'fp16'],
|
||||
default='fp16')
|
||||
parser.add_argument('--dataset_path', type=str, default='')
|
||||
parser.add_argument('--log_level', type=str, default='info')
|
||||
parser.add_argument('--engine_dir', type=str, default='llama_outputs')
|
||||
parser.add_argument('--batch_size', type=int, default=1)
|
||||
parser.add_argument('--max_ite', type=int, default=20)
|
||||
parser.add_argument('--check_accuracy', action='store_true')
|
||||
parser.add_argument('--xtrt_llm_rouge1_threshold', type=float, default=14.5)
|
||||
parser.add_argument('--num_beams', type=int, default=1)
|
||||
parser.add_argument('--top_k', type=int, default=1)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
1360
examples/llama/weight.py
Normal file
1360
examples/llama/weight.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user