add pkgs
This commit is contained in:
0
examples/qwen/-d
Normal file
0
examples/qwen/-d
Normal file
202
examples/qwen/README.md
Normal file
202
examples/qwen/README.md
Normal file
@@ -0,0 +1,202 @@
|
||||
# Qwen
|
||||
|
||||
This document shows how to build and run a Qwen model in XTRT-LLM on both single XPU and single node multi-XPU.
|
||||
|
||||
Support Qwen1.5 model as well
|
||||
|
||||
## Overview
|
||||
|
||||
The XTRT-LLM Qwen example code is located in [`qwen`](./). There is one main file:
|
||||
|
||||
* [`build.py`](./build.py) to build the XTRT-LLM engine(s) needed to run the Qwen model.
|
||||
|
||||
In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation:
|
||||
|
||||
* [`../run.py`](../run.py) to run the inference on an input text;
|
||||
* [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset.
|
||||
|
||||
## Support Matrix
|
||||
* FP16
|
||||
* INT8 Weight-Only
|
||||
* Tensor Parallel
|
||||
|
||||
## Usage
|
||||
|
||||
The XTRT-LLM Qwen example code locates at [qwen](./). It takes HF weights as input, and builds the corresponding XTRT engines. The number of XTRT engines depends on the number of XPUs used to run inference.
|
||||
|
||||
### Build XTRT engine(s)
|
||||
|
||||
Need to prepare the HF Qwen checkpoint first by following the guides here [Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) or [Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat)
|
||||
|
||||
Create a `downloads` directory to store the weights downloaded from huaggingface.
|
||||
```bash
|
||||
mkdir -p ./downloads
|
||||
```
|
||||
|
||||
Store Qwen-7B-Chat or Qwen-14B-Chat separately.
|
||||
- for Qwen-7B-Chat
|
||||
```bash
|
||||
mv Qwen-7B-Chat ./downloads/qwen-7b/
|
||||
```
|
||||
- for Qwen-14B-Chat
|
||||
```bash
|
||||
mv Qwen-14B-Chat ./downloads/qwen-14b/
|
||||
```
|
||||
- for Qwen1.5-7B-Chat
|
||||
```bash
|
||||
mv Qwen1.5-7B-Chat ./downloads/Qwen1.5-7B-Chat/
|
||||
```
|
||||
|
||||
XTRT-LLM Qwen builds XTRT engine(s) from HF checkpoint.
|
||||
|
||||
Normally `build.py` only requires single XPU, but if you've already got all the XPUs needed while inferencing, you could enable parallelly building to make the engine building process faster by adding `--parallel_build` argument. Please note that currently `parallel_build` feature only supports single node.
|
||||
|
||||
** Notice: Qwen1.5 require arg "--version=1.5 **
|
||||
** Notice: `pip install transformers-stream-generator` in build phase**
|
||||
|
||||
Here're some examples:
|
||||
|
||||
```bash
|
||||
# Build a single-XPU float16 engine from HF weights.
|
||||
# use_gpt_attention_plugin is necessary in Qwen.
|
||||
# Try use_gemm_plugin to prevent accuracy issue.
|
||||
# It is recommend to use --use_gpt_attention_plugin for better performance
|
||||
|
||||
# Build the Qwen 7B model using a single XPU and FP16.
|
||||
python build.py --hf_model_dir ./downloads/qwen-7b \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/qwen-7b/trt_engines/fp16/1-XPU/
|
||||
|
||||
# Build the Qwen1.5 7B model using a single XPU and FP16.
|
||||
python build.py --hf_model_dir ./downloads/Qwen1.5-7B-Chat \
|
||||
--version 1.5 \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/Qwen1.5-7B-Chat/trt_engines/fp16/1-XPU/
|
||||
|
||||
# Build the Qwen 7B model using a single XPU and apply INT8 weight-only quantization.
|
||||
python build.py --hf_model_dir ./downloads/qwen-7b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_weight_only \
|
||||
--weight_only_precision int8 \
|
||||
--output_dir ./downloads/qwen-7b/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
# Build Qwen 7B using 2-way tensor parallelism.
|
||||
python build.py --hf_model_dir ./downloads/qwen-7b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/qwen-7b/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2 \
|
||||
--tp_size 2
|
||||
|
||||
|
||||
# Build Qwen 14B using 2-way tensor parallelism.
|
||||
python build.py --hf_model_dir ./downloads/qwen-14b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/qwen-14b/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2 \
|
||||
--tp_size 2
|
||||
```
|
||||
|
||||
#### SmoothQuant
|
||||
|
||||
The smoothquant supports both Qwen v1 and Qwen v2. Unlike the FP16 build where the HF weights are processed and loaded into the XTRT-LLM directly, the SmoothQuant needs to load INT8 weights which should be pre-processed before building an engine.
|
||||
|
||||
Example:
|
||||
```bash
|
||||
python3 hf_qwen_convert.py -i ./downloads/qwen-7b/ -o ./downloads/qwen-7b/sq0.5/ -sq 0.5 --tensor-parallelism 1 --storage-type float16
|
||||
```
|
||||
|
||||
Note `hf_qwen_convert.py` run with PyTorch, and
|
||||
1. `torch-cpu` has better accuracy than xpytorch generally.
|
||||
2. XPyTorch often use more than 32GB GM, thus more XPU are necessary to finish it.
|
||||
3. add `-p=1` if run with XPyTorch.
|
||||
|
||||
[`build.py`](./build.py) add new options for the support of INT8 inference of SmoothQuant models.
|
||||
|
||||
`--use_smooth_quant` is the starting point of INT8 inference. By default, it
|
||||
will run the model in the _per-tensor_ mode.
|
||||
|
||||
`--per-token` and `--per-channel` are not supported yet.
|
||||
|
||||
Examples of build invocations:
|
||||
|
||||
```bash
|
||||
# Build model for SmoothQuant in the _per_tensor_ mode.
|
||||
python3 build.py --ft_dir_path=./downloads/qwen-7b/sq0.5/1-XPU/ \
|
||||
--use_smooth_quant \
|
||||
--hf_model_dir ./downloads/qwen-7b/ \
|
||||
--output_dir ./downloads/qwen-7b/trt_engines/sq0.5/1-XPU/
|
||||
```
|
||||
|
||||
- run
|
||||
```bash
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--engine_dir=./downloads/qwen-7b/trt_engines/sq0.5/1-XPU/
|
||||
```
|
||||
|
||||
- summarize
|
||||
```bash
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir=./downloads/qwen-7b/trt_engines/sq0.5/1-XPU/ \
|
||||
--max_input_length 2048 \
|
||||
--output_len 2048
|
||||
```
|
||||
|
||||
|
||||
### Run
|
||||
|
||||
**Notice: `pip install tiktoken` in run phase**
|
||||
|
||||
To run a XTRT-LLM Qwen model using the engines generated by `build.py`
|
||||
|
||||
```bash
|
||||
# With fp16 inference
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?答:" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--engine_dir=./downloads/qwen-7b/trt_engines/fp16/1-XPU/
|
||||
|
||||
# Qwen1.5 With fp16 inference
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?答:" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/Qwen1.5-7B-Chat/ \
|
||||
--engine_dir=./downloads/Qwen1.5-7B-Chat/trt_engines/fp16/1-XPU/
|
||||
|
||||
# With int8 weight only inference
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?答:" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--engine_dir=./downloads/qwen-7b/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
# Run Qwen 7B model in FP16 using two XPUs.
|
||||
mpirun -n 2 --allow-run-as-root \
|
||||
python ../run.py --input_text "你好,请问你叫什么?答:" \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/qwen-7b/trt_engines/fp16/2-XPU/
|
||||
```
|
||||
**Demo output of run.py:**
|
||||
```bash
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?答:" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--engine_dir ./downloads/qwen-7b/trt_engines/fp16/1-XPU/
|
||||
```
|
||||
```
|
||||
Loading engine from ./downloads/qwen-7b/trt_engines/fp16/1-XPU/qwen_float16_tp1_rank0.engine
|
||||
Input: "<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
你好,请问你叫什么?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"
|
||||
Output: "我是来自阿里云的大规模语言模型,我叫通义千问。"
|
||||
```
|
||||
189
examples/qwen/README_CN.md
Normal file
189
examples/qwen/README_CN.md
Normal file
@@ -0,0 +1,189 @@
|
||||
# Qwen
|
||||
|
||||
本文档描述了如何使用昆仑芯XTRT-LLM中在单XPU和单节点多XPU上构建和运行Qwen模型。
|
||||
|
||||
## 概述
|
||||
|
||||
XTRT-LLM Qwen 示例代码的位置在文件夹`examples/qwen`下,此文件夹下有一个主要文件:
|
||||
|
||||
* [`build.py`](./build.py) 构建运行Qwen模型所需的XTRT-LLM引擎
|
||||
|
||||
除此之外,还有两个可以用来推理和评估的共享文件在父节点 [`examples`](../) 下:
|
||||
|
||||
* [`../run.py`](../run.py) 基于输入的文字进行推理
|
||||
* [`../summarize.py`](../summarize.py) 使用此模型对[cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) 数据集中的文章进行总结
|
||||
|
||||
## 支持的矩阵
|
||||
|
||||
* FP16
|
||||
* INT8 Weight-Only
|
||||
* Tensor Parallel
|
||||
|
||||
## 使用说明
|
||||
|
||||
XTRT-LLM Qwen 示例代码位于 [qwen](./)。它使用HF权重作为输入,并且构建对应的XTRT引擎。XTRT引擎的数量取决于为了运行推理而使用的XPU个数。
|
||||
|
||||
### 构建XTRT引擎
|
||||
|
||||
需要先按照下面的指南准备HF Qwen checkpoint: [Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) 或 [Qwen-14B-Chat](https://huggingface.co/Qwen/Qwen-14B-Chat)
|
||||
|
||||
创建一个 `downloads` 目录,用来存储自Huggingface社区下载的权重。
|
||||
|
||||
```bash
|
||||
mkdir -p ./downloads
|
||||
```
|
||||
|
||||
将Qwen-7B-Chat和Qwen-14B-Chat分开存储。
|
||||
|
||||
- 存储 Qwen-7B-Chat
|
||||
|
||||
```bash
|
||||
mv Qwen-7B-Chat ./downloads/qwen-7b/
|
||||
```
|
||||
|
||||
- 存储 Qwen-14B-Chat
|
||||
|
||||
```bash
|
||||
mv Qwen-14B-Chat ./downloads/qwen-14b/
|
||||
```
|
||||
|
||||
XTRT-LLM从HFcheckpoint构建XTRT引擎。
|
||||
|
||||
通常`build.py`只需要一个XPU,但如果您在推理时已经获得了所需的所有XPU,则可以通过添加`--parallel_build`参数来启用并行构建,从而加快引擎构建过程。请注意,当前并行构建功能仅支持单个节点。
|
||||
|
||||
**请注意:在构建阶段执行安装命令`pip install transformers-stream-generator`**
|
||||
|
||||
以下是一些示例:
|
||||
|
||||
```bash
|
||||
# Build a single-XPU float16 engine from HF weights.
|
||||
# use_gpt_attention_plugin is necessary in Qwen.
|
||||
# Try use_gemm_plugin to prevent accuracy issue.
|
||||
# It is recommend to use --use_gpt_attention_plugin for better performance
|
||||
|
||||
# Build the Qwen 7B model using a single XPU and FP16.
|
||||
python build.py --hf_model_dir ./downloads/qwen-7b \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/qwen-7b/trt_engines/fp16/1-XPU/
|
||||
|
||||
|
||||
# Build the Qwen 7B model using a single XPU and apply INT8 weight-only quantization.
|
||||
python build.py --hf_model_dir ./downloads/qwen-7b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--use_weight_only \
|
||||
--weight_only_precision int8 \
|
||||
--output_dir ./downloads/qwen-7b/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
# Build Qwen 7B using 2-way tensor parallelism.
|
||||
python build.py --hf_model_dir ./downloads/qwen-7b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/qwen-7b/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2 \
|
||||
--tp_size 2
|
||||
|
||||
|
||||
# Build Qwen 14B using 2-way tensor parallelism.
|
||||
python build.py --hf_model_dir ./downloads/qwen-14b/ \
|
||||
--dtype float16 \
|
||||
--use_gpt_attention_plugin float16 \
|
||||
--output_dir ./downloads/qwen-14b/trt_engines/fp16/2-XPU/ \
|
||||
--world_size 2 \
|
||||
--tp_size 2
|
||||
```
|
||||
|
||||
#### SmoothQuant
|
||||
|
||||
SmootQuant同时支持Qwen v1和Qwen v2。与FP16的HF权重可以直接被处理并加载到XTRT-LLM不同,SmoothQuant需要加载INT8权重,而INT8权重在构建引擎之前需要进行预处理。
|
||||
|
||||
示例:
|
||||
```bash
|
||||
python3 hf_qwen_convert.py -i ./downloads/qwen-7b/ -o ./downloads/qwen-7b/sq0.5/ -sq 0.5 --tensor-parallelism 1 --storage-type float16
|
||||
```
|
||||
|
||||
注意:`hf_qwen_convert.py`使用pytorch运行,并且
|
||||
1. 'torch-cpu' 通常比XPyTorch精度更高
|
||||
2. XPyTorch 通常使用超过32GB的GM,因此需要更多的XPU来完成它。
|
||||
3. 使用XPyTorch运行时,请添加`-p=1`。
|
||||
|
||||
`build.py`增加了新的选项来支持SmoothQuant模型的INT8推理。
|
||||
|
||||
`--use_smooth_quant` 是INT8推理的起点。默认情况下,它将以`--per-token`模式运行模型。
|
||||
`--per-token`和`--per-channel`目前还不支持。
|
||||
|
||||
构建调用示例:
|
||||
```bash
|
||||
# Build model for SmoothQuant in the _per_tensor_ mode.
|
||||
python3 build.py --ft_dir_path=./downloads/qwen-7b/sq0.5/1-XPU/ \
|
||||
--use_smooth_quant \
|
||||
--hf_model_dir ./downloads/qwen-7b/ \
|
||||
--output_dir ./downloads/qwen-7b/trt_engines/sq0.5/1-XPU/
|
||||
```
|
||||
|
||||
- 运行
|
||||
```bash
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--engine_dir=./downloads/qwen-7b/trt_engines/sq0.5/1-XPU/
|
||||
```
|
||||
|
||||
- 总结
|
||||
```bash
|
||||
python ../summarize.py --test_trt_llm \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--data_type fp16 \
|
||||
--engine_dir=./downloads/qwen-7b/trt_engines/sq0.5/1-XPU/ \
|
||||
--max_input_length 2048 \
|
||||
--output_len 2048
|
||||
```
|
||||
|
||||
|
||||
### 运行
|
||||
|
||||
**注意:在运行阶段执行安装命令`pip install tiktoken`**
|
||||
|
||||
要使用`build.py`生成的引擎运行XTRT-LLM Qwen模型,请执行以下操作:
|
||||
|
||||
```bash
|
||||
# With fp16 inference
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--engine_dir=./downloads/qwen-7b/trt_engines/fp16/1-XPU/
|
||||
|
||||
|
||||
# With int8 weight only inference
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--engine_dir=./downloads/qwen-7b/trt_engines/int8_weight_only/1-XPU/
|
||||
|
||||
# Run Qwen 7B model in FP16 using two XPUs.
|
||||
mpirun -n 2 --allow-run-as-root \
|
||||
python ../run.py --input_text "你好,请问你叫什么?" \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--max_output_len=50 \
|
||||
--engine_dir ./downloads/qwen-7b/trt_engines/fp16/2-XPU/
|
||||
```
|
||||
|
||||
`run.py`的演示输出:
|
||||
|
||||
```bash
|
||||
python3 ../run.py --input_text "你好,请问你叫什么?" \
|
||||
--max_output_len=50 \
|
||||
--tokenizer_dir ./downloads/qwen-7b/ \
|
||||
--engine_dir ./downloads/qwen-7b/trt_engines/fp16/1-XPU/
|
||||
```
|
||||
```
|
||||
Loading engine from ./downloads/qwen-7b/trt_engines/fp16/1-XPU/qwen_float16_tp1_rank0.engine
|
||||
Input: "<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
你好,请问你叫什么?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"
|
||||
Output: "我是来自阿里云的大规模语言模型,我叫通义千问。"
|
||||
```
|
||||
BIN
examples/qwen/__pycache__/qwen_weight.cpython-38.pyc
Normal file
BIN
examples/qwen/__pycache__/qwen_weight.cpython-38.pyc
Normal file
Binary file not shown.
402
examples/qwen/benchmark.py
Normal file
402
examples/qwen/benchmark.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Benchmark offline inference throughput."""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm, trange
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
from utils.utils import get_stop_words_ids, make_context
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm.runtime import ModelRunner, SamplingConfig
|
||||
|
||||
now_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
MAX_INPUT_LEN = 2048
|
||||
MAX_SEQ_LEN = 4096
|
||||
|
||||
TRT_MAX_BATCH_SIZE = 2
|
||||
TEMPERATURE = 1.0
|
||||
TOP_P = 0.5
|
||||
TOP_K = 1
|
||||
|
||||
|
||||
def sample_requests(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
chat_format: str = "chatml",
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
tokenized_dataset = []
|
||||
for i in trange(len(dataset), desc="Tokenizing for sample"):
|
||||
prompt = dataset[i][0]
|
||||
output_text = dataset[i][1]
|
||||
raw_text, prompt_tokens = make_context(tokenizer=tokenizer,
|
||||
query=prompt,
|
||||
max_input_length=MAX_INPUT_LEN,
|
||||
chat_format=chat_format)
|
||||
new_token_len = len(tokenizer(output_text).input_ids)
|
||||
tokenized_dataset.append((raw_text, prompt_tokens, new_token_len))
|
||||
|
||||
# Filter out too long sequences.
|
||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||
for prompt, prompt_token_ids, new_token_len in tokenized_dataset:
|
||||
prompt_len = len(prompt_token_ids)
|
||||
if prompt_len < 4 or new_token_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if prompt_len > MAX_INPUT_LEN or (prompt_len +
|
||||
new_token_len) > MAX_SEQ_LEN:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
# limit by MAX_SEQ_LEN
|
||||
filtered_dataset.append((prompt, prompt_len, new_token_len))
|
||||
|
||||
# Sample the requests.
|
||||
sampled_requests = random.sample(filtered_dataset, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def run_trt_llm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
engine_dir: str,
|
||||
tokenizer_dir: str,
|
||||
n: int,
|
||||
max_batch_size: int,
|
||||
) -> float:
|
||||
global_max_input_len = MAX_INPUT_LEN
|
||||
global_max_output_len = MAX_SEQ_LEN
|
||||
if max_batch_size > TRT_MAX_BATCH_SIZE:
|
||||
raise Exception(
|
||||
"max batch size {} must be lower than trt_max_batch_size {}".format(
|
||||
max_batch_size, TRT_MAX_BATCH_SIZE))
|
||||
|
||||
# Ad hoc update to ModelRunner
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_dir,
|
||||
legacy=False,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
gen_config_path = os.path.join(tokenizer_dir, 'generation_config.json')
|
||||
with open(gen_config_path, 'r') as f:
|
||||
gen_config = json.load(f)
|
||||
top_k = gen_config['top_k']
|
||||
top_p = gen_config['top_p']
|
||||
chat_format = gen_config['chat_format']
|
||||
if chat_format == "raw":
|
||||
eos_token_id = gen_config['eos_token_id']
|
||||
pad_token_id = gen_config['pad_token_id']
|
||||
elif chat_format == "chatml":
|
||||
pad_token_id = eos_token_id = tokenizer.im_end_id
|
||||
else:
|
||||
raise Exception("unknown chat format ", chat_format)
|
||||
|
||||
sampling_config = SamplingConfig(
|
||||
end_id=eos_token_id,
|
||||
pad_id=pad_token_id,
|
||||
num_beams=1,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
)
|
||||
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
runner = ModelRunner.from_dir(engine_dir, rank=runtime_rank)
|
||||
decoder = runner.session
|
||||
|
||||
# Add the requests to the engine.
|
||||
sampling_config.num_beams = n
|
||||
sampling_config.temperature = 0.0 if n > 1 else TEMPERATURE
|
||||
sampling_config.top_p = TOP_P
|
||||
sampling_config.top_k = TOP_K
|
||||
start = time.time()
|
||||
pad_id = tokenizer.im_end_id
|
||||
|
||||
batch: List[str] = []
|
||||
max_new_tokens = 0
|
||||
total_num_tokens = []
|
||||
for i, (prompt, prompt_len, new_token_len) in tqdm(enumerate(requests),
|
||||
total=len(requests)):
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
max_new_tokens = max(max_new_tokens, new_token_len)
|
||||
if len(batch) < max_batch_size and i < len(requests) - 1:
|
||||
continue
|
||||
input_ids = []
|
||||
input_lengths = []
|
||||
for input_text in batch:
|
||||
input_id = tokenizer(
|
||||
input_text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=global_max_input_len,
|
||||
).input_ids.type(torch.int32)
|
||||
input_ids.append(input_id)
|
||||
input_lengths.append(input_id.shape[-1])
|
||||
# padding
|
||||
max_length = max(input_lengths)
|
||||
# do padding, should move outside the profiling to prevent the overhead
|
||||
for i in range(len(input_ids)):
|
||||
pad_size = max_length - input_lengths[i]
|
||||
|
||||
pad = torch.ones([1, pad_size]).type(torch.int32) * pad_id
|
||||
input_ids[i] = torch.cat([torch.IntTensor(input_ids[i]), pad],
|
||||
axis=-1)
|
||||
# do inference
|
||||
input_ids = torch.cat(input_ids, axis=0).cuda()
|
||||
input_lengths = torch.IntTensor(input_lengths).type(torch.int32).cuda()
|
||||
output_ids = decoder.generate(
|
||||
input_ids=input_ids,
|
||||
input_lengths=input_lengths,
|
||||
sampling_config=sampling_config,
|
||||
max_new_tokens=min(max_new_tokens,
|
||||
global_max_output_len - input_ids.shape[1]),
|
||||
)
|
||||
pure_output_ids = []
|
||||
for i in range(len(batch)):
|
||||
temp_ids = output_ids[i, input_lengths[i]:]
|
||||
pure_ids = []
|
||||
for i in range(len(temp_ids)):
|
||||
if temp_ids[i] in [tokenizer.im_start_id, tokenizer.im_end_id]:
|
||||
pure_ids = temp_ids[:i + 1]
|
||||
break
|
||||
if len(pure_ids) == 0:
|
||||
pure_ids = temp_ids
|
||||
pure_output_ids.append(pure_ids)
|
||||
# get the output text
|
||||
output_texts = [
|
||||
tokenizer.decode(out_ids, skip_special_tokens=True)
|
||||
for out_ids in pure_output_ids
|
||||
]
|
||||
# get the total num of tokens
|
||||
output_lengths = [len(out_ids) for out_ids in pure_output_ids]
|
||||
assert len(output_lengths) == len(batch)
|
||||
for input_len, new_token_len in zip(input_lengths, output_lengths):
|
||||
total_num_tokens.append(input_len + new_token_len)
|
||||
batch = []
|
||||
max_new_tokens = 0
|
||||
|
||||
end = time.time()
|
||||
during = end - start
|
||||
sum_total_num_tokens = sum(total_num_tokens)
|
||||
return during, sum_total_num_tokens
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
n: int,
|
||||
max_batch_size: int,
|
||||
chat_format: str = "chatml",
|
||||
) -> float:
|
||||
global_max_input_len = MAX_INPUT_LEN
|
||||
global_max_output_len = MAX_SEQ_LEN
|
||||
llm = AutoModelForCausalLM.from_pretrained(model,
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif llm.config.model_type == "qwen":
|
||||
tokenizer.pad_token = tokenizer.decode(tokenizer.im_end_id)
|
||||
llm = llm.cuda()
|
||||
stop_words_ids = []
|
||||
stop_words_ids.extend(get_stop_words_ids(chat_format, tokenizer))
|
||||
stop_words_ids2 = [idx for ids in stop_words_ids for idx in ids]
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.time()
|
||||
total_num_tokens = []
|
||||
batch: List[str] = []
|
||||
input_lengths: List[int] = []
|
||||
max_prompt_len = 0
|
||||
max_new_tokens = 0
|
||||
for i in range(len(requests)):
|
||||
prompt, prompt_len, new_token_len = requests[i]
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
input_lengths.append(prompt_len)
|
||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||
max_new_tokens = max(max_new_tokens, new_token_len)
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||
temp_input_max = max(max_prompt_len, next_prompt_len)
|
||||
temp_new_token_max = max(max_new_tokens, next_output_len)
|
||||
if temp_input_max <= global_max_input_len and \
|
||||
(temp_input_max + temp_new_token_max) <= global_max_output_len:
|
||||
continue
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(
|
||||
batch,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=global_max_input_len,
|
||||
).input_ids
|
||||
|
||||
# limit the max_new_tokens
|
||||
max_new_tokens = min(max_new_tokens,
|
||||
global_max_output_len - input_ids.shape[1])
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=True,
|
||||
stop_words_ids=stop_words_ids,
|
||||
num_return_sequences=n,
|
||||
top_k=TOP_K,
|
||||
top_p=TOP_P,
|
||||
temperature=TEMPERATURE,
|
||||
use_cache=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
pure_output_ids = llm_outputs[:, input_ids.shape[-1]:]
|
||||
# get the output text
|
||||
output_texts = tokenizer.batch_decode(pure_output_ids,
|
||||
skip_special_tokens=True)
|
||||
output_lengths = []
|
||||
for out_ids in pure_output_ids:
|
||||
early_stop = False
|
||||
for i in range(len(out_ids)):
|
||||
if out_ids[i] in stop_words_ids2:
|
||||
output_lengths.append(i + 1)
|
||||
early_stop = True
|
||||
break
|
||||
if not early_stop:
|
||||
output_lengths.append(len(out_ids))
|
||||
assert len(output_lengths) == len(batch)
|
||||
for input_len, new_token_len in zip(input_lengths, output_lengths):
|
||||
total_num_tokens.append(input_len + new_token_len)
|
||||
pbar.update(len(batch))
|
||||
|
||||
# Clear the batch.
|
||||
batch = []
|
||||
input_lengths = []
|
||||
max_prompt_len = 0
|
||||
max_new_tokens = 0
|
||||
end = time.time()
|
||||
during = end - start
|
||||
sum_total_num_tokens = sum(total_num_tokens)
|
||||
return during, sum_total_num_tokens
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_dir,
|
||||
padding_side='left',
|
||||
trust_remote_code=True,
|
||||
)
|
||||
requests = sample_requests(tokenizer=tokenizer,
|
||||
dataset_path=args.dataset,
|
||||
num_requests=args.num_prompts,
|
||||
chat_format=args.chat_format)
|
||||
|
||||
if args.backend == "trt_llm":
|
||||
elapsed_time, total_num_tokens = run_trt_llm(
|
||||
requests=requests,
|
||||
engine_dir=args.engine_dir,
|
||||
tokenizer_dir=args.tokenizer_dir,
|
||||
n=args.n,
|
||||
max_batch_size=args.trt_max_batch_size,
|
||||
)
|
||||
elif args.backend == "hf":
|
||||
elapsed_time, total_num_tokens = run_hf(
|
||||
requests=requests,
|
||||
model=args.hf_model_dir,
|
||||
tokenizer=tokenizer,
|
||||
n=args.n,
|
||||
max_batch_size=args.hf_max_batch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
choices=["trt_llm", "hf"],
|
||||
default="trt_llm",
|
||||
)
|
||||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
default=os.path.join(
|
||||
now_dir,
|
||||
"ShareGPT_V3_unfiltered_cleaned_split.json"),
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--hf_model_dir", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir",
|
||||
type=str,
|
||||
default=".",
|
||||
help="Directory containing the tokenizer.model.")
|
||||
parser.add_argument('--engine_dir', type=str, default='qwen_outputs')
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--hf_max_batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Maximum batch size for HF backend.")
|
||||
|
||||
parser.add_argument("--trt_max_batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Maximum batch size for TRT-LLM backend.")
|
||||
parser.add_argument("--chat-format",
|
||||
type=str,
|
||||
default="chatml",
|
||||
choices=["chatml", "raw"],
|
||||
help="choice the model format, base or chat")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.backend == "trt-llm":
|
||||
if args.trt_max_batch_size is None:
|
||||
raise ValueError(
|
||||
"trt max batch size is required for TRT-LLM backend.")
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("hf max batch size is required for HF backend.")
|
||||
if args.tokenizer_dir is None:
|
||||
args.tokenizer_dir = args.hf_model
|
||||
|
||||
main(args)
|
||||
727
examples/qwen/build.py
Normal file
727
examples/qwen/build.py
Normal file
@@ -0,0 +1,727 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
# isort: off
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import tvm.tensorrt as trt
|
||||
# isort: on
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
try:
|
||||
from transformers import Qwen2ForCausalLM
|
||||
except ImportError:
|
||||
print(
|
||||
"Qwen1.5 requires transformers>=4.37.1, type pip install transformers==4.37.1"
|
||||
)
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm._utils import str_dtype_to_xtrt
|
||||
from xtrt_llm.builder import Builder
|
||||
from xtrt_llm.logger import logger
|
||||
from xtrt_llm.mapping import Mapping
|
||||
from xtrt_llm.models import quantize_model
|
||||
from xtrt_llm.network import net_guard
|
||||
from xtrt_llm.plugin.plugin import ContextFMHAType
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
MODEL_NAME = "qwen"
|
||||
|
||||
import onnx
|
||||
import tvm.tensorrt as trt
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
now_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def trt_dtype_to_onnx(dtype):
|
||||
if dtype == trt.float16:
|
||||
return TensorProto.DataType.FLOAT16
|
||||
elif dtype == trt.float32:
|
||||
return TensorProto.DataType.FLOAT
|
||||
elif dtype == trt.int32:
|
||||
return TensorProto.DataType.INT32
|
||||
else:
|
||||
raise TypeError("%s is not supported" % dtype)
|
||||
|
||||
|
||||
def to_onnx(network, path):
|
||||
inputs = []
|
||||
for i in range(network.num_inputs):
|
||||
network_input = network.get_input(i)
|
||||
inputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
network_input.name, trt_dtype_to_onnx(network_input.dtype),
|
||||
list(network_input.shape)))
|
||||
|
||||
outputs = []
|
||||
for i in range(network.num_outputs):
|
||||
network_output = network.get_output(i)
|
||||
outputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
network_output.name, trt_dtype_to_onnx(network_output.dtype),
|
||||
list(network_output.shape)))
|
||||
|
||||
nodes = []
|
||||
for i in range(network.num_layers):
|
||||
layer = network.get_layer(i)
|
||||
layer_inputs = []
|
||||
for j in range(layer.num_inputs):
|
||||
ipt = layer.get_input(j)
|
||||
if ipt is not None:
|
||||
layer_inputs.append(layer.get_input(j).name)
|
||||
layer_outputs = [
|
||||
layer.get_output(j).name for j in range(layer.num_outputs)
|
||||
]
|
||||
nodes.append(
|
||||
helper.make_node(str(layer.type),
|
||||
name=layer.name,
|
||||
inputs=layer_inputs,
|
||||
outputs=layer_outputs,
|
||||
domain="com.nvidia"))
|
||||
|
||||
onnx_model = helper.make_model(helper.make_graph(nodes,
|
||||
'attention',
|
||||
inputs,
|
||||
outputs,
|
||||
initializer=None),
|
||||
producer_name='NVIDIA')
|
||||
onnx.save(onnx_model, path)
|
||||
|
||||
|
||||
def get_engine_name(model, dtype, tp_size, pp_size, rank):
|
||||
if pp_size == 1:
|
||||
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
|
||||
return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
|
||||
pp_size, rank)
|
||||
|
||||
|
||||
def serialize_engine(engine, path):
|
||||
logger.info(f'Serializing engine to {path}...')
|
||||
tik = time.time()
|
||||
engine.serialize(path)
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Engine serialized. Total time: {t}')
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--world_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="world size, only support tensor parallelism now",
|
||||
)
|
||||
parser.add_argument("--tp_size", type=int, default=1)
|
||||
parser.add_argument("--pp_size", type=int, default=1)
|
||||
parser.add_argument("--hf_model_dir", type=str, default=None)
|
||||
parser.add_argument("--version",
|
||||
"-v",
|
||||
type=str,
|
||||
default="1",
|
||||
help="qwen version, support 1, 1.5")
|
||||
parser.add_argument("--ft_dir_path", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float32", "bfloat16", "float16"],
|
||||
)
|
||||
parser.add_argument(
|
||||
'--timing_cache',
|
||||
type=str,
|
||||
default='model.cache',
|
||||
help=
|
||||
'The path of to read timing cache from, will be ignored if the file does not exist'
|
||||
)
|
||||
parser.add_argument('--log_level',
|
||||
type=str,
|
||||
default='info',
|
||||
choices=[
|
||||
'internal_error',
|
||||
'error',
|
||||
'warning',
|
||||
'info',
|
||||
'verbose',
|
||||
])
|
||||
parser.add_argument('--vocab_size', type=int, default=32000)
|
||||
parser.add_argument('--n_layer', type=int, default=32)
|
||||
parser.add_argument('--n_positions', type=int, default=2048)
|
||||
parser.add_argument('--n_embd', type=int, default=4096)
|
||||
parser.add_argument('--n_head', type=int, default=32)
|
||||
parser.add_argument('--n_kv_head', type=int, default=None)
|
||||
parser.add_argument('--inter_size', type=int, default=11008)
|
||||
parser.add_argument('--hidden_act', type=str, default='silu')
|
||||
parser.add_argument('--max_batch_size', type=int, default=2)
|
||||
parser.add_argument('--max_input_len', type=int, default=2048)
|
||||
parser.add_argument('--max_output_len', type=int, default=2048)
|
||||
parser.add_argument('--max_beam_width', type=int, default=1)
|
||||
parser.add_argument('--rotary_base', type=float, default=10000.0)
|
||||
parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None)
|
||||
parser.add_argument('--use_gpt_attention_plugin',
|
||||
nargs='?',
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=['float16', 'bfloat16', 'float32', None])
|
||||
parser.add_argument('--use_gemm_plugin',
|
||||
nargs='?',
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=['float16', 'bfloat16', 'float32', None])
|
||||
parser.add_argument('--parallel_build', default=False, action='store_true')
|
||||
parser.add_argument('--enable_context_fmha',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--enable_context_fmha_fp32_acc',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--visualize', default=False, action='store_true')
|
||||
parser.add_argument('--enable_debug_output',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--gpus_per_node', type=int, default=8)
|
||||
parser.add_argument('--builder_opt', type=int, default=None)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='engine_outputs',
|
||||
help=
|
||||
'The path to save the serialized engine files, timing cache file and model configs'
|
||||
)
|
||||
parser.add_argument('--remove_input_padding',
|
||||
default=False,
|
||||
action='store_true')
|
||||
# Arguments related to the quantization of the model.
|
||||
parser.add_argument(
|
||||
'--use_smooth_quant',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.'
|
||||
'See --per_channel and --per_token for finer-grained quantization options.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_channel',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor for the GEMM\'s result. '
|
||||
'per_channel instead uses a different static scaling factor for each channel. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
parser.add_argument(
|
||||
'--per_token',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale activations in the int8 range. '
|
||||
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
|
||||
parser.add_argument(
|
||||
'--per_group',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale weights in the int4 range. '
|
||||
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
||||
'The flag is built for GPTQ/AWQ quantization.')
|
||||
|
||||
parser.add_argument(
|
||||
'--use_weight_only',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
||||
'See --weight_only_precision to set the precision')
|
||||
|
||||
parser.add_argument(
|
||||
'--weight_only_precision',
|
||||
const='int8',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default='int8',
|
||||
choices=['int8', 'int4'],
|
||||
help=
|
||||
'Define the precision for the weights when using weight-only quantization.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_inflight_batching',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activates inflight batching mode of gptAttentionPlugin.")
|
||||
parser.add_argument(
|
||||
'--paged_kv_cache',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
|
||||
)
|
||||
parser.add_argument('--tokens_per_block',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Number of tokens per block in paged KV cache')
|
||||
|
||||
parser.add_argument(
|
||||
'--max_num_tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Define the max number of tokens supported by the engine')
|
||||
|
||||
parser.add_argument(
|
||||
'--int8_kv_cache',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_parallel_embedding',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--embedding_sharding_dim',
|
||||
type=int,
|
||||
default=1, # Meta does TP on hidden dim
|
||||
choices=[0, 1],
|
||||
help=
|
||||
'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). '
|
||||
'To shard it along hidden dimension, set embedding_sharding_dim=1'
|
||||
'Note: embedding sharing is only enabled when embedding_sharding_dim = 0'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--strongly_typed',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--opt_memory_use',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Whether to use Host memory optimization for building engine')
|
||||
parser.add_argument(
|
||||
'--use_custom_all_reduce',
|
||||
action='store_true',
|
||||
help=
|
||||
'Activates latency-optimized algorithm for all-reduce instead of NCCL.')
|
||||
parser.add_argument('--gather_all_token_logits',
|
||||
action='store_true',
|
||||
default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
assert not (
|
||||
args.use_smooth_quant and args.use_weight_only
|
||||
), "You cannot enable both SmoothQuant and INT8 weight-only together."
|
||||
|
||||
if not args.remove_input_padding:
|
||||
if args.use_gpt_attention_plugin:
|
||||
logger.warning(
|
||||
f"It is recommended to specify --remove_input_padding when using GPT attention plugin"
|
||||
)
|
||||
|
||||
if args.use_inflight_batching:
|
||||
if not args.use_gpt_attention_plugin:
|
||||
args.use_gpt_attention_plugin = 'float16'
|
||||
logger.info(
|
||||
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
|
||||
)
|
||||
if not args.remove_input_padding:
|
||||
args.remove_input_padding = True
|
||||
logger.info(
|
||||
"Using remove input padding for inflight batching mode.")
|
||||
if not args.paged_kv_cache:
|
||||
args.paged_kv_cache = True
|
||||
logger.info("Using paged KV cache for inflight batching mode.")
|
||||
|
||||
if args.use_smooth_quant:
|
||||
args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
|
||||
args.per_channel)
|
||||
elif args.use_weight_only:
|
||||
if args.per_group:
|
||||
args.quant_mode = QuantMode.from_description(
|
||||
quantize_weights=True,
|
||||
quantize_activations=False,
|
||||
per_token=False,
|
||||
per_channel=False,
|
||||
per_group=True,
|
||||
use_int4_weights=True)
|
||||
else:
|
||||
args.quant_mode = QuantMode.use_weight_only(
|
||||
args.weight_only_precision == 'int4')
|
||||
else:
|
||||
args.quant_mode = QuantMode(0)
|
||||
|
||||
if args.int8_kv_cache:
|
||||
args.quant_mode = args.quant_mode.set_int8_kv_cache()
|
||||
|
||||
if args.hf_model_dir is not None:
|
||||
hf_config = AutoConfig.from_pretrained(
|
||||
args.hf_model_dir,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
args.inter_size = hf_config.intermediate_size # override the inter_size for QWen
|
||||
args.n_embd = hf_config.hidden_size
|
||||
args.n_head = hf_config.num_attention_heads
|
||||
if hasattr(hf_config, "num_key_value_heads"):
|
||||
args.n_kv_head = hf_config.num_key_value_heads
|
||||
args.n_layer = hf_config.num_hidden_layers
|
||||
args.n_positions = hf_config.max_position_embeddings
|
||||
args.vocab_size = hf_config.vocab_size
|
||||
args.hidden_act = "silu"
|
||||
if hasattr(hf_config, "kv_channels"):
|
||||
args.kv_channels = hf_config.kv_channels
|
||||
elif hasattr(hf_config, "num_key_value_heads"):
|
||||
args.kv_channels = hf_config.num_key_value_heads
|
||||
else:
|
||||
raise
|
||||
if hasattr(hf_config, "rotary_emb_base"):
|
||||
args.rotary_emb_base = hf_config.rotary_emb_base
|
||||
else:
|
||||
args.rotary_emb_base = 10000.0
|
||||
assert args.use_gpt_attention_plugin is not None, "QWen must use gpt attention plugin"
|
||||
# if args.n_kv_head is not None and args.n_kv_head != args.n_head:
|
||||
# assert (args.n_head % args.n_kv_head) == 0, \
|
||||
# "MQA/GQA requires the number of heads to be divisible by the number of K/V heads."
|
||||
# assert args.n_kv_head == args.tp_size, \
|
||||
# "The current implementation of GQA requires the number of K/V heads to match the number of GPUs." \
|
||||
# "This limitation will be removed in a future version."
|
||||
|
||||
assert args.pp_size * args.tp_size == args.world_size
|
||||
|
||||
if args.max_num_tokens is not None:
|
||||
assert args.enable_context_fmha
|
||||
|
||||
assert (math.log2(args.tokens_per_block).is_integer()
|
||||
), "tokens_per_block must be power of 2"
|
||||
if args.enable_context_fmha or args.enable_context_fmha_fp32_acc:
|
||||
assert (args.tokens_per_block >=
|
||||
128), "Context fMHA requires >= 128 tokens per block"
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def build_rank_engine(builder: Builder,
|
||||
builder_config: xtrt_llm.builder.BuilderConfig,
|
||||
engine_name, rank, multi_query_mode, args):
|
||||
'''
|
||||
@brief: Build the engine on the given rank.
|
||||
@param rank: The rank to build the engine.
|
||||
@param args: The cmd line arguments.
|
||||
@return: The built engine.
|
||||
'''
|
||||
kv_dtype = str_dtype_to_xtrt(args.dtype)
|
||||
mapping = Mapping(world_size=args.world_size,
|
||||
rank=rank,
|
||||
tp_size=args.tp_size,
|
||||
pp_size=args.pp_size)
|
||||
|
||||
# Initialize Module
|
||||
assert args.version in ["1", "1.5"], "Only support version 1 and 1.5"
|
||||
if args.version == "1.5":
|
||||
from qwen2_weight import load_from_ft, load_from_hf_qwen
|
||||
|
||||
xtrt_llm_qwen = xtrt_llm.models.Qwen2ForCausalLM(
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
num_kv_heads=args.n_kv_head,
|
||||
hidden_size=args.n_embd,
|
||||
seq_length=args.max_input_len,
|
||||
vocab_size=args.vocab_size,
|
||||
hidden_act=args.hidden_act,
|
||||
max_position_embeddings=args.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mlp_hidden_size=args.inter_size,
|
||||
mapping=mapping,
|
||||
rotary_base=args.rotary_base,
|
||||
rotary_scaling=args.rotary_scaling,
|
||||
use_parallel_embedding=args.use_parallel_embedding,
|
||||
embedding_sharding_dim=args.embedding_sharding_dim,
|
||||
quant_mode=args.quant_mode,
|
||||
gather_all_token_logits=args.gather_all_token_logits,
|
||||
)
|
||||
else:
|
||||
from qwen_weight import load_from_ft, load_from_hf_qwen
|
||||
|
||||
xtrt_llm_qwen = xtrt_llm.models.QWenForCausalLM(
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
num_kv_heads=args.n_kv_head,
|
||||
hidden_size=args.n_embd,
|
||||
seq_length=args.max_input_len,
|
||||
vocab_size=args.vocab_size,
|
||||
hidden_act=args.hidden_act,
|
||||
max_position_embeddings=args.n_positions,
|
||||
dtype=kv_dtype,
|
||||
mlp_hidden_size=args.inter_size,
|
||||
neox_rotary_style=True,
|
||||
mapping=mapping,
|
||||
rotary_base=args.rotary_base,
|
||||
rotary_scaling=args.rotary_scaling,
|
||||
use_parallel_embedding=args.use_parallel_embedding,
|
||||
embedding_sharding_dim=args.embedding_sharding_dim,
|
||||
quant_mode=args.quant_mode,
|
||||
gather_all_token_logits=args.gather_all_token_logits,
|
||||
)
|
||||
|
||||
quantize_kwargs = {}
|
||||
if args.use_smooth_quant or args.use_weight_only:
|
||||
if args.weight_only_precision == 'int4_awq':
|
||||
quantize_kwargs = {
|
||||
"group_size": args.group_size,
|
||||
"zero": False,
|
||||
"pre_quant_scale": True,
|
||||
"exclude_modules": [],
|
||||
}
|
||||
elif args.weight_only_precision == 'int4_gptq':
|
||||
quantize_kwargs = {
|
||||
"group_size": args.group_size,
|
||||
"zero": True,
|
||||
"pre_quant_scale": False,
|
||||
}
|
||||
xtrt_llm_qwen = quantize_model(xtrt_llm_qwen, args.quant_mode,
|
||||
**quantize_kwargs)
|
||||
ft_dir_path = args.ft_dir_path
|
||||
if args.hf_model_dir is not None and \
|
||||
(ft_dir_path is None or not os.path.exists(ft_dir_path)):
|
||||
logger.info(f'Loading HF QWen ... from {args.hf_model_dir}')
|
||||
tik = time.time()
|
||||
|
||||
if args.version == "1":
|
||||
hf_qwen = AutoModelForCausalLM.from_pretrained(
|
||||
args.hf_model_dir,
|
||||
device_map={
|
||||
"transformer": "cpu",
|
||||
"lm_head": "cpu",
|
||||
}, # Load to CPU memory
|
||||
torch_dtype="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
else:
|
||||
hf_qwen = Qwen2ForCausalLM.from_pretrained(
|
||||
args.hf_model_dir,
|
||||
# device_map="cpu",
|
||||
device_map={
|
||||
"model": "cpu",
|
||||
"lm_head": "cpu"
|
||||
}, # Load to CPU memory
|
||||
torch_dtype="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'HF QWen loaded. Total time: {t}')
|
||||
load_from_hf_qwen(xtrt_llm_qwen,
|
||||
hf_qwen,
|
||||
mapping,
|
||||
max_position_embeddings=args.n_positions,
|
||||
kv_channels=args.kv_channels,
|
||||
rotary_emb_base=args.rotary_emb_base,
|
||||
dtype=args.dtype,
|
||||
multi_query_mode=multi_query_mode)
|
||||
del hf_qwen
|
||||
elif ft_dir_path is not None:
|
||||
dir_path = ft_dir_path
|
||||
logger.info(f'Loading FT QWen ... from {ft_dir_path}')
|
||||
load_from_ft(xtrt_llm_qwen,
|
||||
dir_path,
|
||||
mapping,
|
||||
dtype=args.dtype,
|
||||
multi_query_mode=multi_query_mode)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You must specify either --hf_model_dir or --ft_dir_path")
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
if args.use_gpt_attention_plugin:
|
||||
network.plugin_config.set_gpt_attention_plugin(
|
||||
dtype=args.use_gpt_attention_plugin)
|
||||
if args.use_gemm_plugin:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
|
||||
# Quantization plugins.
|
||||
if args.use_smooth_quant:
|
||||
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_quantize_tensor_plugin()
|
||||
network.plugin_config.set_quantize_per_token_plugin()
|
||||
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
|
||||
if args.enable_context_fmha:
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
if args.enable_context_fmha_fp32_acc:
|
||||
network.plugin_config.set_context_fmha(
|
||||
ContextFMHAType.enabled_with_fp32_acc)
|
||||
if args.use_weight_only:
|
||||
if args.per_group:
|
||||
network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
else:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
if args.quant_mode.is_weight_only():
|
||||
builder_config.trt_builder_config.use_weight_only = args.weight_only_precision
|
||||
if args.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(args.dtype,
|
||||
args.use_custom_all_reduce)
|
||||
if args.remove_input_padding:
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
|
||||
if args.paged_kv_cache:
|
||||
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
|
||||
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(xtrt_llm_qwen.named_parameters())
|
||||
|
||||
# Forward
|
||||
inputs = xtrt_llm_qwen.prepare_inputs(
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_input_len=args.max_input_len,
|
||||
max_new_tokens=args.max_output_len,
|
||||
use_cache=True,
|
||||
max_beam_width=args.max_beam_width,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
)
|
||||
xtrt_llm_qwen(*inputs)
|
||||
if args.enable_debug_output:
|
||||
# mark intermediate nodes' outputs
|
||||
for k, v in xtrt_llm_qwen.named_network_outputs():
|
||||
v = v.trt_tensor
|
||||
v.name = k
|
||||
network.trt_network.mark_output(v)
|
||||
v.dtype = kv_dtype
|
||||
if args.visualize:
|
||||
model_path = os.path.join(args.output_dir, 'test.onnx')
|
||||
to_onnx(network.trt_network, model_path)
|
||||
|
||||
engine = None
|
||||
|
||||
# Network -> Engine
|
||||
engine = builder.build_engine(network, builder_config)
|
||||
if rank == 0:
|
||||
config_path = os.path.join(args.output_dir, 'config.json')
|
||||
builder.save_config(builder_config, config_path)
|
||||
|
||||
if args.opt_memory_use:
|
||||
return engine, network
|
||||
return engine
|
||||
|
||||
|
||||
def build(rank, args):
|
||||
torch.cuda.set_device(rank % args.gpus_per_node)
|
||||
xtrt_llm.logger.set_level(args.log_level)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
multi_query_mode = (args.n_kv_head
|
||||
is not None) and (args.n_kv_head != args.n_head)
|
||||
|
||||
# when doing serializing build, all ranks share one engine
|
||||
builder = Builder()
|
||||
|
||||
cache = None
|
||||
for cur_rank in range(args.world_size):
|
||||
# skip other ranks if parallel_build is enabled
|
||||
if args.parallel_build and cur_rank != rank:
|
||||
continue
|
||||
int8_trt_flag = args.quant_mode.has_act_and_weight_quant() or (
|
||||
not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache())
|
||||
builder_config = builder.create_builder_config(
|
||||
name=MODEL_NAME,
|
||||
precision=args.dtype,
|
||||
timing_cache=args.timing_cache if cache is None else cache,
|
||||
tensor_parallel=args.tp_size,
|
||||
pipeline_parallel=args.pp_size,
|
||||
parallel_build=args.parallel_build,
|
||||
num_layers=args.n_layer,
|
||||
num_heads=args.n_head,
|
||||
hidden_size=args.n_embd,
|
||||
inter_size=args.inter_size,
|
||||
vocab_size=args.vocab_size,
|
||||
hidden_act=args.hidden_act,
|
||||
max_position_embeddings=args.n_positions,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_beam_width=args.max_beam_width,
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
fusion_pattern_list=["remove_dup_mask"],
|
||||
int8=int8_trt_flag,
|
||||
fp8=args.quant_mode.has_fp8_qdq(),
|
||||
quant_mode=args.quant_mode,
|
||||
strongly_typed=args.strongly_typed,
|
||||
opt_level=args.builder_opt,
|
||||
max_prompt_embedding_table_size=0,
|
||||
# max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
|
||||
gather_all_token_logits=args.gather_all_token_logits)
|
||||
guard = xtrt_llm.fusion_patterns.FuseonPatternGuard()
|
||||
print(guard)
|
||||
engine_name = get_engine_name(MODEL_NAME, args.dtype, args.tp_size,
|
||||
args.pp_size, cur_rank)
|
||||
if args.opt_memory_use:
|
||||
engine, network = build_rank_engine(builder, builder_config,
|
||||
engine_name, cur_rank,
|
||||
multi_query_mode, args)
|
||||
else:
|
||||
engine = build_rank_engine(builder, builder_config, engine_name,
|
||||
cur_rank, multi_query_mode, args)
|
||||
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
|
||||
|
||||
if cur_rank == 0:
|
||||
# Use in-memory timing cache for multiple builder passes.
|
||||
if not args.parallel_build:
|
||||
cache = builder_config.trt_builder_config.get_timing_cache()
|
||||
|
||||
serialize_engine(engine, os.path.join(args.output_dir, engine_name))
|
||||
|
||||
del engine
|
||||
if args.opt_memory_use:
|
||||
network.__del__()
|
||||
|
||||
# if rank == 0:
|
||||
# ok = builder.save_timing_cache(
|
||||
# builder_config, os.path.join(args.output_dir, "model.cache"))
|
||||
# assert ok, "Failed to save timing cache."
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
logger.set_level(args.log_level)
|
||||
tik = time.time()
|
||||
if args.version == "1.5":
|
||||
MODEL_NAME = 'qwen2'
|
||||
if args.parallel_build and args.world_size > 1 and \
|
||||
torch.cuda.device_count() >= args.world_size:
|
||||
logger.warning(
|
||||
f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
|
||||
)
|
||||
mp.spawn(build, nprocs=args.world_size, args=(args, ))
|
||||
else:
|
||||
args.parallel_build = False
|
||||
logger.info('Serially build TensorRT engines.')
|
||||
build(0, args)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Total time of building all {args.world_size} engines: {t}')
|
||||
361
examples/qwen/hf_qwen_convert.py
Normal file
361
examples/qwen/hf_qwen_convert.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
'''
|
||||
Convert huggingface QWen-7B-Chat model to numpy file.
|
||||
Use https://huggingface.co/Qwen/Qwen-7B-Chat as demo.
|
||||
'''
|
||||
import argparse
|
||||
import configparser
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as multiprocessing
|
||||
from smoothquant import capture_activation_range, smooth_gemm, smooth_gemm_mlp
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM # transformers-4.10.0-py3
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
# for debug
|
||||
from utils.convert import split_and_save_weight
|
||||
|
||||
from xtrt_llm._utils import str_dtype_to_torch, torch_to_numpy
|
||||
|
||||
now_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ProgArgs:
|
||||
out_dir: str
|
||||
in_file: str
|
||||
max_input_len: int = 2048
|
||||
tensor_parallelism: int = 1
|
||||
processes: int = 1
|
||||
calibrate_kv_cache: bool = False
|
||||
smoothquant: float = None
|
||||
model: str = "qwen"
|
||||
storage_type: str = "fp32"
|
||||
dataset_cache_dir: str = None
|
||||
|
||||
@staticmethod
|
||||
def parse(args=None) -> 'ProgArgs':
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
parser.add_argument('--out-dir',
|
||||
'-o',
|
||||
type=str,
|
||||
help='file name of output directory',
|
||||
required=True)
|
||||
parser.add_argument('--in-file',
|
||||
'-i',
|
||||
type=str,
|
||||
help='file name of input checkpoint file',
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
'--max_input_len',
|
||||
type=int,
|
||||
help=
|
||||
"This should be consistent with the max_input_len you used when building engine.",
|
||||
default=2048)
|
||||
parser.add_argument('--tensor-parallelism',
|
||||
'-tp',
|
||||
type=int,
|
||||
help='Requested tensor parallelism for inference',
|
||||
default=1)
|
||||
parser.add_argument(
|
||||
"--processes",
|
||||
"-p",
|
||||
type=int,
|
||||
help=
|
||||
"How many processes to spawn for conversion (default: 1). Set it to a lower value to reduce RAM usage.",
|
||||
default=1)
|
||||
parser.add_argument(
|
||||
"--calibrate-kv-cache",
|
||||
"-kv",
|
||||
action="store_true",
|
||||
help=
|
||||
"Generate scaling factors for KV cache. Used for storing KV cache in int8."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--smoothquant",
|
||||
"-sq",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
|
||||
" to Smoothquant the model, and output int8 weights."
|
||||
" A good first try is 0.5. Must be in [0, 1]")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="qwen",
|
||||
type=str,
|
||||
help="Specify GPT variants to convert checkpoints correctly",
|
||||
choices=["qwen", "gpt2", "santacoder", "starcoder"])
|
||||
parser.add_argument("--storage-type",
|
||||
"-t",
|
||||
type=str,
|
||||
default="float16",
|
||||
choices=["float32", "float16", "bfloat16"])
|
||||
parser.add_argument("--dataset-cache-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="cache dir to load the hugging face dataset")
|
||||
return ProgArgs(**vars(parser.parse_args(args)))
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_qwen_model(model, scales, alpha, qwen_smoother):
|
||||
# Smooth the activation and weights with smoother = $\diag{s}$
|
||||
for name, module in model.named_modules():
|
||||
# if not isinstance(module, QWenBlock):
|
||||
if not str(type(module)).endswith("QWenBlock'>"):
|
||||
continue
|
||||
|
||||
# qkv_proj
|
||||
layer_name = name + ".attn.c_attn"
|
||||
smoother = smooth_gemm(module.attn.c_attn.weight,
|
||||
scales[layer_name]["x"],
|
||||
module.ln_1.weight,
|
||||
alpha=alpha)
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
|
||||
scales[layer_name]["w"] = module.attn.c_attn.weight.abs().max(dim=1)[0]
|
||||
|
||||
# attention dense
|
||||
layer_name = name + ".attn.c_proj"
|
||||
smoother3 = smooth_gemm(
|
||||
module.attn.c_proj.weight,
|
||||
scales[layer_name]["x"],
|
||||
None,
|
||||
alpha=alpha,
|
||||
)
|
||||
qwen_smoother[layer_name] = smoother3.float()
|
||||
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother3
|
||||
scales[layer_name]["w"] = module.attn.c_proj.weight.abs().max(dim=1)[0]
|
||||
|
||||
# mlp w1 / w2, because then use some input hidden_states as input, so we need to smooth it with same scale
|
||||
mlp_w1_name = name + ".mlp.w1"
|
||||
mlp_w2_name = name + ".mlp.w2"
|
||||
smoother2 = smooth_gemm_mlp(module.mlp.w1.weight,
|
||||
module.mlp.w2.weight,
|
||||
scales[mlp_w1_name]["x"],
|
||||
module.ln_2.weight,
|
||||
alpha=alpha)
|
||||
scales[mlp_w1_name]["x"] = scales[mlp_w1_name]["x"] / smoother2
|
||||
scales[mlp_w2_name]["x"] = scales[mlp_w2_name]["x"] / smoother2
|
||||
scales[mlp_w1_name]["w"] = module.mlp.w1.weight.abs().max(dim=1)[0]
|
||||
scales[mlp_w2_name]["w"] = module.mlp.w2.weight.abs().max(dim=1)[0]
|
||||
|
||||
# mlp c_proj
|
||||
layer_name = name + ".mlp.c_proj"
|
||||
smoother4 = smooth_gemm(module.mlp.c_proj.weight,
|
||||
scales[layer_name]["x"],
|
||||
None,
|
||||
alpha=alpha)
|
||||
qwen_smoother[layer_name] = smoother4.float()
|
||||
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother4
|
||||
scales[layer_name]["w"] = module.mlp.c_proj.weight.abs().max(dim=1)[0]
|
||||
|
||||
|
||||
# SantaCoder separates Q projection from KV projection
|
||||
def concat_qkv_weight_bias(q, hf_key, hf_model):
|
||||
kv = hf_model.state_dict()[hf_key.replace("q_attn", "kv_attn")]
|
||||
return torch.cat([q, kv], dim=-1)
|
||||
|
||||
|
||||
# StarCoder uses nn.Linear for these following ops whose weight matrix is transposed compared to transformer.Conv1D
|
||||
def transpose_weights(hf_name, param):
|
||||
weight_to_transpose = [
|
||||
"attn.c_attn", "attn.c_proj", "mlp.c_proj", "mlp.w1", "mlp.w2"
|
||||
]
|
||||
if any([k in hf_name for k in weight_to_transpose]):
|
||||
if len(param.shape) == 2:
|
||||
param = param.transpose(0, 1)
|
||||
return param
|
||||
|
||||
|
||||
def convert_qwen_name(orig_name):
|
||||
global_weights = {
|
||||
"transformer.wte.weight": "vocab_embedding.weight",
|
||||
"transformer.ln_f.weight": "ln_f.weight",
|
||||
"lm_head.weight": "lm_head.weight"
|
||||
}
|
||||
|
||||
if orig_name in global_weights:
|
||||
return global_weights[orig_name]
|
||||
|
||||
_, _, layer_id, *weight_name = orig_name.split(".")
|
||||
layer_id = int(layer_id)
|
||||
weight_name = "transformer." + ".".join(weight_name)
|
||||
|
||||
per_layer_weights = {
|
||||
"transformer.ln_1.weight": "ln_1.weight",
|
||||
"transformer.ln_2.weight": "ln_2.weight",
|
||||
"transformer.attn.c_attn.weight": "attention.qkv.weight",
|
||||
"transformer.attn.c_attn.bias": "attention.qkv.bias",
|
||||
"transformer.attn.c_proj.weight": "attention.dense.weight",
|
||||
"transformer.mlp.w1.weight": "mlp.w1.weight",
|
||||
"transformer.mlp.w2.weight": "mlp.w2.weight",
|
||||
"transformer.mlp.c_proj.weight": "mlp.c_proj.weight",
|
||||
}
|
||||
return f"layers.{layer_id}.{per_layer_weights[weight_name]}"
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def hf_qwen_converter(args: ProgArgs):
|
||||
infer_tp = args.tensor_parallelism
|
||||
multi_query_mode = True if args.model in ["santacoder", "starcoder"
|
||||
] else False
|
||||
saved_dir = Path(args.out_dir) / f"{infer_tp}-XPU"
|
||||
saved_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# load position_embedding from rank 0
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.in_file,
|
||||
device_map=
|
||||
"auto", # if you gpu memory is not enough, you can set device_map="cpu"
|
||||
trust_remote_code=True,
|
||||
torch_dtype=str_dtype_to_torch(args.storage_type),
|
||||
).float() # if you gpu memory is not enough, you can set .half() to .float()
|
||||
model.generation_config = GenerationConfig.from_pretrained(
|
||||
args.in_file, trust_remote_code=True)
|
||||
act_range = {}
|
||||
qwen_smoother = {}
|
||||
if args.smoothquant is not None or args.calibrate_kv_cache:
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
|
||||
"TOKENIZERS_PARALLELISM", "false")
|
||||
from datasets import load_dataset
|
||||
|
||||
# copy from summarize.py
|
||||
dataset_cnn = load_dataset("ccdv/cnn_dailymail", '3.0.0')
|
||||
dataset = dataset_cnn["test"]
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.in_file,
|
||||
legacy=False,
|
||||
padding_side='left',
|
||||
trust_remote_code=True,
|
||||
)
|
||||
gen_config_path = os.path.join(args.in_file, 'generation_config.json')
|
||||
with open(gen_config_path, 'r') as f:
|
||||
gen_config = json.load(f)
|
||||
chat_format = gen_config['chat_format']
|
||||
tokenizer.pad_token_id = tokenizer.im_end_id
|
||||
# use this prompt to make chat model do summarize
|
||||
system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user."
|
||||
act_range = capture_activation_range(
|
||||
model,
|
||||
tokenizer,
|
||||
dataset,
|
||||
system_prompt=system_prompt,
|
||||
chat_format=chat_format,
|
||||
max_input_len=args.max_input_len,
|
||||
)
|
||||
if args.smoothquant is not None:
|
||||
smooth_qwen_model(model, act_range, args.smoothquant, qwen_smoother)
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config["qwen"] = {}
|
||||
for key in vars(args):
|
||||
config["qwen"][key] = f"{vars(args)[key]}"
|
||||
for k, v in vars(model.config).items():
|
||||
config["qwen"][k] = f"{v}"
|
||||
config["qwen"]["storage_dtype"] = args.storage_type
|
||||
config["qwen"]["multi_query_mode"] = str(multi_query_mode)
|
||||
with open(saved_dir / "config.ini", 'w') as configfile:
|
||||
config.write(configfile)
|
||||
|
||||
storage_type = str_dtype_to_torch(args.storage_type)
|
||||
|
||||
global_weights = ["vocab_embedding.weight", "ln_f.weight", "lm_head.weight"]
|
||||
|
||||
int8_outputs = None
|
||||
if args.calibrate_kv_cache:
|
||||
int8_outputs = "kv_cache_only"
|
||||
if args.smoothquant is not None:
|
||||
int8_outputs = "all"
|
||||
|
||||
starmap_args = []
|
||||
for name, param in tqdm(
|
||||
model.named_parameters(),
|
||||
desc="convert and save",
|
||||
total=len(list(model.parameters())),
|
||||
ncols=80,
|
||||
):
|
||||
if "weight" not in name and "bias" not in name:
|
||||
continue
|
||||
converted_name = convert_qwen_name(name)
|
||||
if name.replace(".weight", "") in qwen_smoother.keys():
|
||||
smoother = qwen_smoother[name.replace(".weight", "")]
|
||||
starmap_arg = (
|
||||
0,
|
||||
saved_dir,
|
||||
infer_tp,
|
||||
f"{converted_name}.smoother".replace(".weight", ""),
|
||||
smoother,
|
||||
storage_type,
|
||||
None,
|
||||
{
|
||||
"int8_outputs": int8_outputs,
|
||||
"multi_query_mode": multi_query_mode,
|
||||
"local_dim": None,
|
||||
},
|
||||
)
|
||||
if args.processes > 1:
|
||||
starmap_args.append(starmap_arg)
|
||||
else:
|
||||
split_and_save_weight(*starmap_arg)
|
||||
|
||||
param = transpose_weights(name, param)
|
||||
if converted_name in global_weights:
|
||||
torch_to_numpy(param.to(storage_type).cpu()).tofile(
|
||||
saved_dir / f"{converted_name}.bin")
|
||||
else:
|
||||
if 'q_attn' in name:
|
||||
param = concat_qkv_weight_bias(param, name, model)
|
||||
converted_name = converted_name.replace("query",
|
||||
"query_key_value")
|
||||
# Needed by QKV projection weight split. With multi_query_mode one does not simply take
|
||||
# out_dim and divide it by 3 to get local_dim because out_dim = local_dim + 2 * head_size
|
||||
local_dim = model.transformer.h[
|
||||
0].attn.embed_dim if multi_query_mode else None
|
||||
starmap_arg = (0, saved_dir, infer_tp, converted_name,
|
||||
param.to(storage_type), storage_type,
|
||||
act_range.get(name.replace(".weight", "")), {
|
||||
"int8_outputs": int8_outputs,
|
||||
"multi_query_mode": multi_query_mode,
|
||||
"local_dim": local_dim
|
||||
})
|
||||
if args.processes > 1:
|
||||
starmap_args.append(starmap_arg)
|
||||
else:
|
||||
split_and_save_weight(*starmap_arg)
|
||||
|
||||
if args.processes > 1:
|
||||
starmap_args = tqdm(starmap_args, desc="saving weights")
|
||||
with multiprocessing.Pool(args.processes) as pool:
|
||||
pool.starmap(split_and_save_weight, starmap_args)
|
||||
|
||||
|
||||
def run_conversion(args: ProgArgs):
|
||||
print("\n=============== Arguments ===============")
|
||||
for key, value in vars(args).items():
|
||||
print(f"{key}: {value}")
|
||||
print("========================================")
|
||||
hf_qwen_converter(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
run_conversion(ProgArgs.parse())
|
||||
1220
examples/qwen/qwen2_weight.py
Normal file
1220
examples/qwen/qwen2_weight.py
Normal file
File diff suppressed because it is too large
Load Diff
564
examples/qwen/qwen_weight.py
Normal file
564
examples/qwen/qwen_weight.py
Normal file
@@ -0,0 +1,564 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import configparser
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
import xtrt_llm
|
||||
from xtrt_llm._utils import str_dtype_to_np, str_dtype_to_torch, torch_to_numpy
|
||||
from xtrt_llm.mapping import Mapping
|
||||
from xtrt_llm.models import QWenForCausalLM
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
|
||||
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
|
||||
suffix = f"{rank}.bin"
|
||||
if use_smooth_quant:
|
||||
sq_prefix = "int8."
|
||||
if quant_per_channel:
|
||||
sq_prefix += "col."
|
||||
suffix = sq_prefix + suffix
|
||||
return suffix
|
||||
|
||||
|
||||
def extract_layer_idx(name):
|
||||
ss = name.split('.')
|
||||
for s in ss:
|
||||
if s.isdigit():
|
||||
return s
|
||||
return None
|
||||
|
||||
|
||||
def custom_slice(array, begin, end, axis):
|
||||
if axis < 0:
|
||||
axis += len(array.shape)
|
||||
assert axis >= 0 and axis < len(array.shape), \
|
||||
f"Invalid axis {axis} for array with shape {array.shape}"
|
||||
if axis == 0:
|
||||
return array[begin:end]
|
||||
elif axis == 1:
|
||||
return array[:, begin:end]
|
||||
elif axis == 2:
|
||||
return array[:, :, begin:end]
|
||||
elif axis == 3:
|
||||
return array[:, :, :, begin:end]
|
||||
elif axis == 4:
|
||||
return array[:, :, :, :, begin:end]
|
||||
elif axis == 5:
|
||||
return array[:, :, :, :, :, begin:end]
|
||||
elif axis == 6:
|
||||
return array[:, :, :, :, :, :, begin:end]
|
||||
else:
|
||||
raise ValueError(f"Unsupported axis {axis}")
|
||||
|
||||
|
||||
def split(v, tp_size, idx, dim=0):
|
||||
if tp_size == 1:
|
||||
return v
|
||||
if len(v.shape) == 1:
|
||||
if v.shape[0] % tp_size != 0:
|
||||
# padding 0 to align the split
|
||||
pad_tensor = np.zeros([tp_size - v.shape[0] % tp_size],
|
||||
dtype=v.dtype)
|
||||
v = np.concatenate([v, pad_tensor])
|
||||
return np.ascontiguousarray(np.split(v, tp_size)[idx])
|
||||
else:
|
||||
if dim < 0:
|
||||
dim += len(v.shape)
|
||||
slice_size = (v.shape[dim] + tp_size - 1) // tp_size
|
||||
bound = v.shape[dim]
|
||||
nd = custom_slice(v,
|
||||
idx * slice_size,
|
||||
min((idx + 1) * slice_size, bound),
|
||||
axis=dim)
|
||||
if (idx + 1) * slice_size > bound:
|
||||
pad_shape = list(v.shape)
|
||||
pad_shape[dim] = tp_size - v.shape[dim] % tp_size
|
||||
pad_tensor = np.zeros(pad_shape, dtype=v.dtype)
|
||||
nd = np.concatenate([nd, pad_tensor], axis=dim)
|
||||
return np.ascontiguousarray(nd)
|
||||
|
||||
|
||||
def parse_ft_config(ini_file):
|
||||
qwen_config = configparser.ConfigParser()
|
||||
qwen_config.read(ini_file)
|
||||
|
||||
vocab_size = qwen_config.getint('qwen', 'vocab_size')
|
||||
hidden_size = qwen_config.getint('qwen', 'hidden_size')
|
||||
inter_size = qwen_config.getint('qwen', 'intermediate_size', fallback=None)
|
||||
num_hidden_layers = qwen_config.getint(
|
||||
"qwen",
|
||||
"num_hidden_layers",
|
||||
fallback=32,
|
||||
)
|
||||
max_position_embeddings = qwen_config.getint("qwen",
|
||||
"max_position_embeddings",
|
||||
fallback=8192)
|
||||
kv_channels = qwen_config.getint('qwen', 'kv_channels', fallback=128)
|
||||
rotary_pct = qwen_config.getfloat('qwen', 'rotary_pct', fallback=0.0)
|
||||
rotary_emb_base = qwen_config.getint('qwen',
|
||||
'rotary_emb_base',
|
||||
fallback=10000)
|
||||
multi_query_mode = qwen_config.getboolean('qwen',
|
||||
'multi_query_mode',
|
||||
fallback=False)
|
||||
return (vocab_size, hidden_size, inter_size, num_hidden_layers, kv_channels,
|
||||
rotary_pct, rotary_emb_base, multi_query_mode,
|
||||
max_position_embeddings)
|
||||
|
||||
|
||||
def load_from_ft(xtrt_llm_qwen: QWenForCausalLM,
|
||||
dir_path,
|
||||
mapping=Mapping(),
|
||||
dtype='float16',
|
||||
share_embedding_table=False,
|
||||
parallel_embedding_table=False,
|
||||
multi_query_mode=False):
|
||||
xtrt_llm.logger.info('Loading weights from FT...')
|
||||
tik = time.time()
|
||||
quant_mode = getattr(xtrt_llm_qwen, 'quant_mode', QuantMode(0))
|
||||
if quant_mode.is_int8_weight_only():
|
||||
plugin_weight_only_quant_type = torch.int8
|
||||
elif quant_mode.is_int4_weight_only():
|
||||
plugin_weight_only_quant_type = torch.quint4x2
|
||||
(vocab_size, hidden_size, inter_size, num_hidden_layers, kv_channels,
|
||||
rotary_pct, rotary_emb_base, multi_query_mode,
|
||||
max_position_embeddings) = parse_ft_config(Path(dir_path) / 'config.ini')
|
||||
np_dtype = str_dtype_to_np(dtype)
|
||||
|
||||
def fromfile(dir_path, name, shape=None, dtype=np.float16):
|
||||
dtype = np_dtype if dtype is None else dtype
|
||||
p = dir_path + '/' + name
|
||||
if Path(p).exists():
|
||||
t = np.fromfile(p, dtype=dtype)
|
||||
if shape is not None:
|
||||
t = t.reshape(shape)
|
||||
return t
|
||||
else:
|
||||
print(f"Warning: {p} not found.")
|
||||
return None
|
||||
|
||||
def set_smoothquant_scale_factors(
|
||||
module,
|
||||
pre_scale_weight,
|
||||
dir_path,
|
||||
basename,
|
||||
shape,
|
||||
per_tok_dyn,
|
||||
per_channel,
|
||||
is_qkv=False,
|
||||
rank=None,
|
||||
):
|
||||
suffix = "bin"
|
||||
if per_channel:
|
||||
if rank is not None:
|
||||
suffix = f"{rank}." + suffix
|
||||
suffix = "col." + suffix
|
||||
|
||||
col_shape = shape if (per_channel or is_qkv) else [1, 1]
|
||||
if per_tok_dyn:
|
||||
if pre_scale_weight is not None:
|
||||
pre_scale_weight.value = np.array([1.0], dtype=np.float32)
|
||||
t = fromfile(dir_path, f"{basename}scale_w_quant_orig.{suffix}",
|
||||
col_shape, np.float32)
|
||||
module.per_channel_scale.value = t
|
||||
else:
|
||||
t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1],
|
||||
np.float32)
|
||||
pre_scale_weight.value = t
|
||||
t = fromfile(dir_path, f"{basename}scale_y_accum_quant.{suffix}",
|
||||
col_shape, np.float32)
|
||||
module.per_channel_scale.value = t
|
||||
t = fromfile(dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1],
|
||||
np.float32)
|
||||
module.act_scale.value = t
|
||||
|
||||
def set_smoother(module, dir_path, base_name, shape, rank):
|
||||
suffix = f"{rank}.bin"
|
||||
t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape,
|
||||
np.float32)
|
||||
module.smoother.value = t
|
||||
|
||||
# Determine the quantization mode.
|
||||
quant_mode = getattr(xtrt_llm_qwen, "quant_mode", QuantMode(0))
|
||||
# Do we use SmoothQuant?
|
||||
use_smooth_quant = quant_mode.has_act_and_weight_quant()
|
||||
# Do we use quantization per token?
|
||||
quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling()
|
||||
# Do we use quantization per channel?
|
||||
quant_per_channel = quant_mode.has_per_channel_scaling()
|
||||
|
||||
# Do we use INT4/INT8 weight-only?
|
||||
use_weight_only = quant_mode.is_weight_only()
|
||||
|
||||
# Int8 KV cache
|
||||
use_int8_kv_cache = quant_mode.has_int8_kv_cache()
|
||||
|
||||
# Debug
|
||||
suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel)
|
||||
# The type of weights.
|
||||
w_type = np_dtype if not use_smooth_quant else np.int8
|
||||
|
||||
if mapping.is_first_pp_rank():
|
||||
xtrt_llm_qwen.vocab_embedding.weight.value = (fromfile(
|
||||
dir_path, 'vocab_embedding.weight.bin', [vocab_size, hidden_size]))
|
||||
|
||||
if mapping.is_last_pp_rank():
|
||||
xtrt_llm_qwen.ln_f.weight.value = (fromfile(dir_path,
|
||||
'ln_f.weight.bin'))
|
||||
|
||||
lm_head_weight = fromfile(dir_path, 'lm_head.weight.bin',
|
||||
[vocab_size, hidden_size])
|
||||
|
||||
if vocab_size % mapping.tp_size != 0:
|
||||
# padding
|
||||
vocab_size_padded = xtrt_llm_qwen.lm_head.out_features * mapping.tp_size
|
||||
pad_width = vocab_size_padded - vocab_size
|
||||
lm_head_weight = np.pad(lm_head_weight, ((0, pad_width), (0, 0)),
|
||||
'constant',
|
||||
constant_values=0)
|
||||
if mapping.is_last_pp_rank():
|
||||
xtrt_llm_qwen.lm_head.weight.value = np.ascontiguousarray(
|
||||
split(lm_head_weight, mapping.tp_size, mapping.tp_rank))
|
||||
|
||||
layers_range = list(
|
||||
range(mapping.pp_rank * xtrt_llm_qwen.num_layers,
|
||||
(mapping.pp_rank + 1) * xtrt_llm_qwen.num_layers, 1))
|
||||
|
||||
for i in layers_range:
|
||||
c_attn_out_dim = (3 * hidden_size //
|
||||
mapping.tp_size) if not multi_query_mode else (
|
||||
hidden_size // mapping.tp_size +
|
||||
(hidden_size // num_hidden_layers) * 2)
|
||||
|
||||
xtrt_llm_qwen.layers[i].ln_1.weight.value = fromfile(
|
||||
dir_path, 'model.layers.' + str(i) + '.ln_1.weight.bin')
|
||||
|
||||
dst = xtrt_llm_qwen.layers[i].ln_2.weight
|
||||
dst.value = fromfile(dir_path,
|
||||
'model.layers.' + str(i) + '.ln_2.weight.bin')
|
||||
|
||||
t = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.attention.qkv.weight.' + suffix,
|
||||
[hidden_size, c_attn_out_dim], w_type)
|
||||
if t is not None:
|
||||
dst = xtrt_llm_qwen.layers[i].attention.qkv.weight
|
||||
if use_smooth_quant:
|
||||
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
|
||||
set_smoothquant_scale_factors(
|
||||
xtrt_llm_qwen.layers[i].attention.qkv,
|
||||
xtrt_llm_qwen.layers[i].ln_1.scale_to_int,
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.attention.qkv.',
|
||||
[1, c_attn_out_dim],
|
||||
quant_per_token_dyn,
|
||||
quant_per_channel,
|
||||
rank=mapping.tp_rank,
|
||||
is_qkv=True)
|
||||
elif use_weight_only:
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(t), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[i].attention.qkv.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
|
||||
|
||||
dst = xtrt_llm_qwen.layers[i].attention.qkv.bias
|
||||
t = fromfile(
|
||||
dir_path, 'model.layers.' + str(i) + '.attention.qkv.bias.' +
|
||||
str(mapping.tp_rank) + '.bin', [c_attn_out_dim])
|
||||
dst.value = np.ascontiguousarray(t)
|
||||
|
||||
dst = xtrt_llm_qwen.layers[i].attention.dense.weight
|
||||
t = fromfile(
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.attention.dense.weight.' + suffix,
|
||||
[hidden_size // mapping.tp_size, hidden_size], w_type)
|
||||
if use_smooth_quant:
|
||||
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
|
||||
dense_scale = getattr(xtrt_llm_qwen.layers[i].attention,
|
||||
"quantization_scaling_factor", None)
|
||||
set_smoothquant_scale_factors(
|
||||
xtrt_llm_qwen.layers[i].attention.dense,
|
||||
dense_scale,
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.attention.dense.',
|
||||
[1, hidden_size],
|
||||
quant_per_token_dyn,
|
||||
quant_per_channel,
|
||||
)
|
||||
set_smoother(xtrt_llm_qwen.layers[i].attention.dense, dir_path,
|
||||
'model.layers.' + str(i) + '.attention.dense',
|
||||
[1, hidden_size // mapping.tp_size], mapping.tp_rank)
|
||||
|
||||
elif use_weight_only:
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(t), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[i].attention.dense.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
|
||||
|
||||
t = fromfile(dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.w1.weight.' + suffix,
|
||||
[hidden_size, inter_size // mapping.tp_size // 2], w_type)
|
||||
if use_smooth_quant:
|
||||
xtrt_llm_qwen.layers[
|
||||
i].mlp.gate.weight.value = np.ascontiguousarray(
|
||||
np.transpose(t, [1, 0]))
|
||||
set_smoothquant_scale_factors(
|
||||
xtrt_llm_qwen.layers[i].mlp.gate,
|
||||
xtrt_llm_qwen.layers[i].ln_2.scale_to_int,
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.w1.',
|
||||
[1, inter_size // mapping.tp_size // 2],
|
||||
quant_per_token_dyn,
|
||||
quant_per_channel,
|
||||
rank=mapping.tp_rank)
|
||||
elif use_weight_only:
|
||||
dst = xtrt_llm_qwen.layers[i].mlp.gate.weight
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(t), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[i].mlp.gate.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
xtrt_llm_qwen.layers[
|
||||
i].mlp.gate.weight.value = np.ascontiguousarray(
|
||||
np.transpose(t, [1, 0]))
|
||||
|
||||
t = fromfile(dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.w2.weight.' + suffix,
|
||||
[hidden_size, inter_size // mapping.tp_size // 2], w_type)
|
||||
if use_smooth_quant:
|
||||
xtrt_llm_qwen.layers[i].mlp.fc.weight.value = np.ascontiguousarray(
|
||||
np.transpose(t, [1, 0]))
|
||||
set_smoothquant_scale_factors(
|
||||
xtrt_llm_qwen.layers[i].mlp.fc,
|
||||
xtrt_llm_qwen.layers[i].ln_2.scale_to_int,
|
||||
dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.w2.',
|
||||
[1, inter_size // mapping.tp_size // 2],
|
||||
quant_per_token_dyn,
|
||||
quant_per_channel,
|
||||
rank=mapping.tp_rank)
|
||||
elif use_weight_only:
|
||||
dst = xtrt_llm_qwen.layers[i].mlp.fc.weight
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(t), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[i].mlp.fc.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
xtrt_llm_qwen.layers[i].mlp.fc.weight.value = np.ascontiguousarray(
|
||||
np.transpose(t, [1, 0]))
|
||||
|
||||
t = fromfile(dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.c_proj.weight.' + suffix,
|
||||
[inter_size // mapping.tp_size // 2, hidden_size], w_type)
|
||||
if use_smooth_quant:
|
||||
xtrt_llm_qwen.layers[
|
||||
i].mlp.proj.weight.value = np.ascontiguousarray(
|
||||
np.transpose(t, [1, 0]))
|
||||
proj_scale = getattr(xtrt_llm_qwen.layers[i].mlp,
|
||||
"quantization_scaling_factor", None)
|
||||
set_smoothquant_scale_factors(
|
||||
xtrt_llm_qwen.layers[i].mlp.proj, proj_scale, dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.c_proj.', [1, hidden_size],
|
||||
quant_per_token_dyn, quant_per_channel)
|
||||
set_smoother(xtrt_llm_qwen.layers[i].mlp.proj, dir_path,
|
||||
'model.layers.' + str(i) + '.mlp.c_proj',
|
||||
[1, inter_size // mapping.tp_size // 2],
|
||||
mapping.tp_rank)
|
||||
elif use_weight_only:
|
||||
dst = xtrt_llm_qwen.layers[i].mlp.proj.weight
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(t), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[i].mlp.proj.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
xtrt_llm_qwen.layers[
|
||||
i].mlp.proj.weight.value = np.ascontiguousarray(
|
||||
np.transpose(t, [1, 0]))
|
||||
|
||||
if use_int8_kv_cache:
|
||||
t = fromfile(
|
||||
dir_path, 'model.layers.' + str(i) +
|
||||
'.attention.qkv.scale_y_quant_orig.bin', [1], np.float32)
|
||||
xtrt_llm_qwen.layers[
|
||||
i].attention.kv_orig_quant_scale.value = 1.0 / t
|
||||
xtrt_llm_qwen.layers[i].attention.kv_quant_orig_scale.value = t
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|
||||
|
||||
|
||||
def load_from_hf_qwen(xtrt_llm_qwen: xtrt_llm.models.QWenForCausalLM,
|
||||
hf_qwen,
|
||||
mapping=Mapping(),
|
||||
max_position_embeddings=8192,
|
||||
rotary_emb_base=10000,
|
||||
kv_channels=128,
|
||||
dtype="float32",
|
||||
multi_query_mode=False):
|
||||
xtrt_llm.logger.info('Loading weights from HF QWen...')
|
||||
tik = time.time()
|
||||
|
||||
quant_mode = getattr(xtrt_llm_qwen, 'quant_mode', QuantMode(0))
|
||||
if quant_mode.is_int8_weight_only():
|
||||
plugin_weight_only_quant_type = torch.int8
|
||||
elif quant_mode.is_int4_weight_only():
|
||||
plugin_weight_only_quant_type = torch.quint4x2
|
||||
# use_weight_only = quant_mode.is_weight_only()
|
||||
use_weight_only = 0
|
||||
|
||||
model_params = dict(hf_qwen.named_parameters())
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
for k, v in tqdm(model_params.items(),
|
||||
total=len(model_params),
|
||||
ncols=80,
|
||||
desc="Converting..."):
|
||||
if isinstance(v, list):
|
||||
v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v]
|
||||
else:
|
||||
v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
|
||||
if 'transformer.wte.weight' in k:
|
||||
if xtrt_llm_qwen.use_parallel_embedding:
|
||||
v = split(v, mapping.tp_size, mapping.tp_rank,
|
||||
xtrt_llm_qwen.embedding_sharding_dim)
|
||||
if mapping.is_first_pp_rank():
|
||||
xtrt_llm_qwen.vocab_embedding.weight.value = v
|
||||
elif 'transformer.ln_f.weight' in k:
|
||||
xtrt_llm_qwen.ln_f.weight.value = v
|
||||
elif 'lm_head.weight' in k:
|
||||
xtrt_llm_qwen.lm_head.weight.value = np.ascontiguousarray(
|
||||
split(v, mapping.tp_size, mapping.tp_rank))
|
||||
else:
|
||||
layer_idx = extract_layer_idx(k)
|
||||
if layer_idx is None:
|
||||
continue
|
||||
idx = int(layer_idx)
|
||||
if idx >= xtrt_llm_qwen.num_layers:
|
||||
continue
|
||||
if 'ln_1.weight' in k:
|
||||
xtrt_llm_qwen.layers[idx].ln_1.weight.value = v
|
||||
elif 'ln_2.weight' in k:
|
||||
xtrt_llm_qwen.layers[idx].ln_2.weight.value = v
|
||||
elif 'attn.c_attn.weight' in k:
|
||||
dst = xtrt_llm_qwen.layers[idx].attention.qkv.weight
|
||||
if multi_query_mode:
|
||||
assert isinstance(v, list) and len(v) == 3
|
||||
wq = split(v[0], mapping.tp_size, mapping.tp_rank)
|
||||
wk = split(v[1], mapping.tp_size, mapping.tp_rank)
|
||||
wv = split(v[2], mapping.tp_size, mapping.tp_rank)
|
||||
split_v = np.concatenate((wq, wk, wv))
|
||||
else:
|
||||
q_emb = v.shape[0] // 3
|
||||
model_emb = v.shape[1]
|
||||
v = v.reshape(3, q_emb, model_emb)
|
||||
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
|
||||
split_v = split_v.reshape(3 * (q_emb // mapping.tp_size),
|
||||
model_emb)
|
||||
if use_weight_only:
|
||||
v = np.ascontiguousarray(split_v.transpose())
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(v), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[
|
||||
idx].attention.qkv.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
dst.value = np.ascontiguousarray(split_v)
|
||||
elif 'attn.c_attn.bias' in k:
|
||||
dst = xtrt_llm_qwen.layers[idx].attention.qkv.bias
|
||||
if multi_query_mode:
|
||||
assert isinstance(v, list) and len(v) == 3
|
||||
wq = split(v[0], mapping.tp_size, mapping.tp_rank)
|
||||
wk = split(v[1], mapping.tp_size, mapping.tp_rank)
|
||||
wv = split(v[2], mapping.tp_size, mapping.tp_rank)
|
||||
split_v = np.concatenate((wq, wk, wv))
|
||||
else:
|
||||
q_emb = v.shape[0] // 3
|
||||
v = v.reshape(3, q_emb)
|
||||
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
|
||||
split_v = split_v.reshape(3 * (q_emb // mapping.tp_size))
|
||||
dst.value = np.ascontiguousarray(split_v)
|
||||
elif 'attn.c_proj.weight' in k:
|
||||
dst = xtrt_llm_qwen.layers[idx].attention.dense.weight
|
||||
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
|
||||
if use_weight_only:
|
||||
v = np.ascontiguousarray(split_v.transpose())
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(v), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[
|
||||
idx].attention.dense.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
dst.value = np.ascontiguousarray(split_v)
|
||||
elif 'mlp.w1.weight' in k:
|
||||
dst = xtrt_llm_qwen.layers[idx].mlp.gate.weight
|
||||
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0)
|
||||
if use_weight_only:
|
||||
v = np.ascontiguousarray(split_v.transpose())
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(v), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[
|
||||
idx].mlp.gate.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
dst.value = np.ascontiguousarray(split_v)
|
||||
elif 'mlp.w2.weight' in k:
|
||||
dst = xtrt_llm_qwen.layers[idx].mlp.fc.weight
|
||||
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0)
|
||||
if use_weight_only:
|
||||
v = np.ascontiguousarray(split_v.transpose())
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(v), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[idx].mlp.fc.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
dst.value = np.ascontiguousarray(split_v)
|
||||
elif 'mlp.c_proj.weight' in k:
|
||||
dst = xtrt_llm_qwen.layers[idx].mlp.proj.weight
|
||||
split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
|
||||
if use_weight_only:
|
||||
v = np.ascontiguousarray(split_v.transpose())
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
torch.tensor(v), plugin_weight_only_quant_type)
|
||||
dst.value = processed_torch_weights.numpy()
|
||||
scales = xtrt_llm_qwen.layers[
|
||||
idx].mlp.proj.per_channel_scale
|
||||
scales.value = torch_weight_scales.numpy()
|
||||
else:
|
||||
dst.value = np.ascontiguousarray(split_v)
|
||||
else:
|
||||
print("unknown key: ", k)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
xtrt_llm.logger.info(f'Weights loaded. Total time: {t}')
|
||||
return
|
||||
16
examples/qwen/requirements.txt
Normal file
16
examples/qwen/requirements.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
datasets~=2.3.2
|
||||
evaluate~=0.4.1
|
||||
rouge_score~=0.1.2
|
||||
transformers==4.37.1
|
||||
accelerate==0.21.0
|
||||
transformers-stream-generator
|
||||
sentencepiece~=0.1.99
|
||||
tiktoken
|
||||
einops
|
||||
|
||||
# optional dependencies
|
||||
gradio==3.40.1
|
||||
mdtex2html
|
||||
sse_starlette
|
||||
aiohttp_sse_client
|
||||
openai
|
||||
209
examples/qwen/smoothquant.py
Normal file
209
examples/qwen/smoothquant.py
Normal file
@@ -0,0 +1,209 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
'''
|
||||
Utilities for SmoothQuant models
|
||||
'''
|
||||
|
||||
import functools
|
||||
import os
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
project_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(project_dir)
|
||||
from utils.utils import make_context
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_smoothing(scales,
|
||||
gemm_weights,
|
||||
rmsnorm_weights=None,
|
||||
dtype=torch.float32,
|
||||
rmsnorm_1p=False):
|
||||
if not isinstance(gemm_weights, list):
|
||||
gemm_weights = [gemm_weights]
|
||||
|
||||
if rmsnorm_weights is not None:
|
||||
assert rmsnorm_weights.numel() == scales.numel()
|
||||
rmsnorm_weights.div_(scales).to(dtype)
|
||||
if rmsnorm_1p:
|
||||
rmsnorm_weights += (1 / scales) - 1
|
||||
|
||||
for gemm in gemm_weights:
|
||||
gemm.mul_(scales.view(1, -1)).to(dtype)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_gemm(gemm_weights,
|
||||
act_scales,
|
||||
rmsnorm_weights=None,
|
||||
alpha=0.5,
|
||||
weight_scales=None):
|
||||
if not isinstance(gemm_weights, list):
|
||||
gemm_weights = [gemm_weights]
|
||||
orig_dtype = gemm_weights[0].dtype
|
||||
|
||||
for gemm in gemm_weights:
|
||||
# gemm_weights are expected to be transposed
|
||||
assert gemm.shape[1] == act_scales.numel()
|
||||
|
||||
if weight_scales is None:
|
||||
weight_scales = torch.cat(
|
||||
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
|
||||
dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0]
|
||||
weight_scales.to(float).clamp(min=1e-5)
|
||||
scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5)
|
||||
|
||||
apply_smoothing(scales, gemm_weights, rmsnorm_weights, orig_dtype)
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_gemm_mlp(w1_weights,
|
||||
w2_weights,
|
||||
act_scales,
|
||||
rmsnorm_weights=None,
|
||||
alpha=0.5,
|
||||
weight_scales=None):
|
||||
gemm_weights = []
|
||||
if not isinstance(w1_weights, list):
|
||||
w1_weights = [w1_weights]
|
||||
if not isinstance(w2_weights, list):
|
||||
w2_weights = [w2_weights]
|
||||
|
||||
for i in range(len(w1_weights)):
|
||||
gemm_weight = torch.cat([w1_weights[i], w2_weights[i]], dim=0)
|
||||
gemm_weights.append(gemm_weight)
|
||||
|
||||
orig_dtype = gemm_weights[0].dtype
|
||||
|
||||
for gemm in gemm_weights:
|
||||
# gemm_weights are expected to be transposed
|
||||
assert gemm.shape[1] == act_scales.numel()
|
||||
|
||||
if weight_scales is None:
|
||||
weight_scales = torch.cat(
|
||||
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
|
||||
dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0]
|
||||
weight_scales.to(float).clamp(min=1e-5)
|
||||
scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5)
|
||||
|
||||
apply_smoothing(scales, w1_weights + w2_weights, rmsnorm_weights,
|
||||
orig_dtype)
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
|
||||
if not isinstance(fcs, list):
|
||||
fcs = [fcs]
|
||||
for fc in fcs:
|
||||
assert isinstance(fc, nn.Linear)
|
||||
assert ln.weight.numel() == fc.in_features == act_scales.numel()
|
||||
|
||||
device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
|
||||
act_scales = act_scales.to(device=device, dtype=dtype)
|
||||
weight_scales = torch.cat(
|
||||
[fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
|
||||
|
||||
scales = (act_scales.pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
|
||||
|
||||
if ln is not None:
|
||||
ln.weight.div_(scales)
|
||||
ln.bias.div_(scales)
|
||||
|
||||
for fc in fcs:
|
||||
fc.weight.mul_(scales.view(1, -1))
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def capture_activation_range(
|
||||
model,
|
||||
tokenizer,
|
||||
dataset,
|
||||
system_prompt,
|
||||
chat_format,
|
||||
max_input_len,
|
||||
num_samples=512,
|
||||
):
|
||||
model.eval()
|
||||
device = next(model.parameters()).device
|
||||
act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
|
||||
|
||||
def stat_tensor(name, tensor, act_scales, key):
|
||||
hidden_dim = tensor.shape[-1]
|
||||
tensor = tensor.view(-1, hidden_dim).abs().detach()
|
||||
comming_max = torch.max(tensor, dim=0)[0].float()
|
||||
|
||||
if act_scales[name][key] is None:
|
||||
act_scales[name][key] = comming_max
|
||||
else:
|
||||
act_scales[name][key] = torch.max(act_scales[name][key],
|
||||
comming_max)
|
||||
|
||||
def stat_input_hook(m, x, y, name):
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
stat_tensor(name, x, act_scales, "x")
|
||||
stat_tensor(name, y, act_scales, "y")
|
||||
|
||||
if act_scales[name]["w"] is None:
|
||||
act_scales[name]["w"] = m.weight.abs().clip(1e-8,
|
||||
None).max(dim=1)[0]
|
||||
|
||||
hooks = []
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
|
||||
hooks.append(
|
||||
m.register_forward_hook(
|
||||
functools.partial(stat_input_hook, name=name)))
|
||||
num_samples = min(num_samples, len(dataset))
|
||||
for i in tqdm(range(num_samples), desc="calibrating model"):
|
||||
line = dataset[i]["article"]
|
||||
line = line + ' TL;DR: '
|
||||
line = line.strip()
|
||||
line = line.replace(" n't", "n't")
|
||||
# use make_content to generate prompt
|
||||
_, input_id_list = make_context(tokenizer=tokenizer,
|
||||
query=line,
|
||||
history=[],
|
||||
system=system_prompt,
|
||||
chat_format=chat_format,
|
||||
max_input_length=max_input_len)
|
||||
line_encoded = torch.from_numpy(np.array(
|
||||
input_id_list, dtype=np.int32)).type(torch.int32).unsqueeze(0)
|
||||
line_encoded = line_encoded.to(device)
|
||||
model(line_encoded)
|
||||
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
|
||||
return act_scales
|
||||
14
examples/qwen/utils/__init__.py
Normal file
14
examples/qwen/utils/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
304
examples/qwen/utils/convert.py
Normal file
304
examples/qwen/utils/convert.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities for exporting a model to our custom format.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from xtrt_llm._utils import torch_to_numpy
|
||||
|
||||
|
||||
def cpu_map_location(storage, loc):
|
||||
return storage.cpu()
|
||||
|
||||
|
||||
def gpu_map_location(storage, loc):
|
||||
if loc.startswith("cuda"):
|
||||
training_gpu_idx = int(loc.split(":")[1])
|
||||
inference_gpu_idx = training_gpu_idx % torch.cuda.device_count()
|
||||
return storage.cuda(inference_gpu_idx)
|
||||
elif loc.startswith("cpu"):
|
||||
return storage.cpu()
|
||||
else:
|
||||
raise ValueError(f"Not handled {loc}")
|
||||
|
||||
|
||||
def save_val(val, dir, key, tp_num=None):
|
||||
suffix = "bin" if tp_num is None else f"{tp_num}.bin"
|
||||
val.tofile(dir / f"model.{key}.{suffix}")
|
||||
|
||||
|
||||
def save_split(split_vals, dir, key, i, split_factor):
|
||||
for j, val in enumerate(split_vals):
|
||||
save_val(val, dir, key, i * split_factor + j)
|
||||
|
||||
|
||||
def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
|
||||
"""
|
||||
This function has two purposes:
|
||||
- compute quantized weights, scaled either per-tensor or per-column
|
||||
- compute scaling factors
|
||||
|
||||
Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
|
||||
CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W.
|
||||
CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor.
|
||||
|
||||
Here is the list of what we need (T means per-tensor, C per-column):
|
||||
- scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T)
|
||||
- scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
|
||||
- scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
|
||||
- scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
|
||||
to quant range (int8) (used for CUBLAS) (T, C)
|
||||
|
||||
Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too,
|
||||
but then the model would change depending on the number of GPUs used.
|
||||
|
||||
For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it
|
||||
as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V.
|
||||
"""
|
||||
|
||||
# compute weight scaling factors for fp->int8 and int8->fp
|
||||
if is_qkv and not multi_query_mode:
|
||||
scale_w_orig_quant_t = 127. / torch_to_numpy(act_range["w"].reshape(
|
||||
3, -1).max(dim=-1, keepdims=True)[0].cpu()).astype(np.float32)
|
||||
scale_w_orig_quant_c = 127. / torch_to_numpy(act_range["w"].reshape(
|
||||
3, -1).cpu()).astype(np.float32)
|
||||
elif is_qkv and multi_query_mode:
|
||||
raise ValueError(
|
||||
f"Multi-query w/ int8 quant has not been supported yet")
|
||||
else:
|
||||
scale_w_orig_quant_t = 127. / torch_to_numpy(
|
||||
act_range["w"].max().cpu()).astype(np.float32)
|
||||
scale_w_orig_quant_c = 127. / torch_to_numpy(
|
||||
act_range["w"].cpu()).astype(np.float32)
|
||||
scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
|
||||
scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c
|
||||
|
||||
# compute the rest of needed scaling factors
|
||||
scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item())
|
||||
scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item())
|
||||
scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.)
|
||||
scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t *
|
||||
scale_w_orig_quant_t)
|
||||
scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t *
|
||||
scale_w_orig_quant_c)
|
||||
if is_qkv:
|
||||
scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t,
|
||||
scale_w_orig_quant_c.shape)
|
||||
scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t,
|
||||
scale_w_orig_quant_c.shape)
|
||||
|
||||
to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8)
|
||||
return {
|
||||
"weight.int8": to_i8(weights * scale_w_orig_quant_t),
|
||||
"weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
|
||||
"scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32),
|
||||
"scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32),
|
||||
"scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32),
|
||||
"scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32),
|
||||
"scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32),
|
||||
"scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32),
|
||||
}
|
||||
|
||||
|
||||
def write_int8(vals,
|
||||
dir,
|
||||
base_key,
|
||||
split_dim,
|
||||
tp_rank,
|
||||
split_factor,
|
||||
kv_cache_only=False):
|
||||
if not kv_cache_only:
|
||||
save_split(np.split(vals["weight.int8"], split_factor, axis=split_dim),
|
||||
dir, f"{base_key}.weight.int8", tp_rank, split_factor)
|
||||
save_split(
|
||||
np.split(vals["weight.int8.col"], split_factor, axis=split_dim),
|
||||
dir, f"{base_key}.weight.int8.col", tp_rank, split_factor)
|
||||
|
||||
saved_keys_once = ["scale_y_quant_orig"]
|
||||
if not kv_cache_only:
|
||||
saved_keys_once += [
|
||||
"scale_x_orig_quant", "scale_w_quant_orig", "scale_y_accum_quant"
|
||||
]
|
||||
# per-column scaling factors are loaded per-gpu for ColumnParallel GEMMs (QKV, FC1)
|
||||
if not kv_cache_only:
|
||||
if split_dim == -1:
|
||||
save_split(
|
||||
np.split(vals["scale_w_quant_orig.col"],
|
||||
split_factor,
|
||||
axis=split_dim), dir,
|
||||
f"{base_key}.scale_w_quant_orig.col", tp_rank, split_factor)
|
||||
save_split(
|
||||
np.split(vals["scale_y_accum_quant.col"],
|
||||
split_factor,
|
||||
axis=split_dim), dir,
|
||||
f"{base_key}.scale_y_accum_quant.col", tp_rank, split_factor)
|
||||
else:
|
||||
saved_keys_once += [
|
||||
"scale_w_quant_orig.col", "scale_y_accum_quant.col"
|
||||
]
|
||||
|
||||
if tp_rank == 0:
|
||||
for save_key in saved_keys_once:
|
||||
save_val(vals[save_key], dir, f"{base_key}.{save_key}")
|
||||
|
||||
|
||||
# Note: in multi_query_mode, only query heads are split between multiple GPUs, while key/value head
|
||||
# are not split as there is only one head per key/value.
|
||||
@torch.no_grad()
|
||||
def split_and_save_weight(tp_rank, saved_dir, split_factor, key, vals,
|
||||
storage_type, act_range, config):
|
||||
use_attention_nemo_shape = config.get("use_attention_nemo_shape", False)
|
||||
split_gated_activation = config.get("split_gated_activation", False)
|
||||
num_attention_heads = config.get("num_attention_heads", 0)
|
||||
tp_size = config.get("tp_size", 1)
|
||||
int8_outputs = config.get("int8_outputs", None)
|
||||
multi_query_mode = config.get("multi_query_mode", False)
|
||||
local_dim = config.get("local_dim", None)
|
||||
|
||||
save_int8 = int8_outputs == "all" or int8_outputs == "kv_cache_only"
|
||||
|
||||
if not key.endswith(".smoother"):
|
||||
if not isinstance(vals, list):
|
||||
vals = [vals]
|
||||
|
||||
if config.get("transpose_weights", False) and vals[0].ndim == 2:
|
||||
vals = [val.T for val in vals]
|
||||
if "layernorm.weight" in key and config.get("apply_layernorm_1p",
|
||||
False):
|
||||
vals = [val + 1.0 for val in vals]
|
||||
vals = [torch_to_numpy(val.cpu().to(storage_type)) for val in vals]
|
||||
else:
|
||||
vals = torch_to_numpy(vals.cpu())
|
||||
|
||||
if "ln_1.weight" in key or "ln_1.bias" in key or \
|
||||
"attention.dense.bias" in key or \
|
||||
"ln_2.weight" in key or "ln_2.bias" in key or \
|
||||
"mlp.c_proj.bias" in key or "ln_f.weight" in key:
|
||||
# "final_layernorm.weight" in key or "final_layernorm.bias" in key:
|
||||
|
||||
# shared weights, only need to convert the weights of rank 0
|
||||
if tp_rank == 0:
|
||||
save_val(vals[0], saved_dir, key)
|
||||
|
||||
elif "attention.dense.weight" in key or "mlp.c_proj.weight" in key:
|
||||
cat_dim = 0
|
||||
val = np.concatenate(vals, axis=cat_dim)
|
||||
split_vals = np.split(val, split_factor, axis=cat_dim)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
if act_range is not None and int8_outputs == "all":
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val,
|
||||
act_range,
|
||||
multi_query_mode=multi_query_mode)
|
||||
write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank,
|
||||
split_factor)
|
||||
|
||||
elif "mlp.w1.weight" in key or "mlp.w2.weight" in key or "mlp.w1.bias" in key or "mlp.w2.bias" in key:
|
||||
if split_gated_activation:
|
||||
splits = [np.split(val, 2, axis=-1) for val in vals]
|
||||
vals, gates = list(zip(*splits))
|
||||
cat_dim = -1
|
||||
val = np.concatenate(vals, axis=cat_dim)
|
||||
split_vals = np.split(val, split_factor, axis=cat_dim)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
if act_range is not None and int8_outputs == "all":
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val,
|
||||
act_range,
|
||||
multi_query_mode=multi_query_mode)
|
||||
write_int8(vals_i8, saved_dir, base_key, cat_dim, tp_rank,
|
||||
split_factor)
|
||||
|
||||
if split_gated_activation:
|
||||
assert not save_int8
|
||||
prefix, dot, suffix = key.rpartition(".")
|
||||
key = prefix + ".gate" + dot + suffix
|
||||
|
||||
gate = np.concatenate(gates, axis=cat_dim)
|
||||
split_vals = np.split(gate, split_factor, axis=cat_dim)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
|
||||
elif "attention.qkv.bias" in key:
|
||||
if local_dim is None:
|
||||
local_dim = vals[0].shape[-1] // 3
|
||||
|
||||
if multi_query_mode:
|
||||
val = vals[0]
|
||||
# out_feature = local_dim + 2 * head_size; assumes local_dim equals to hidden_dim
|
||||
b_q, b_kv = np.split(val, [local_dim], axis=-1)
|
||||
b_q_split = np.split(b_q, split_factor, axis=-1)
|
||||
split_vals = [np.concatenate((i, b_kv), axis=-1) for i in b_q_split]
|
||||
else:
|
||||
if use_attention_nemo_shape:
|
||||
head_num = num_attention_heads // tp_size
|
||||
size_per_head = local_dim // num_attention_heads
|
||||
nemo_shape = (head_num, 3, size_per_head)
|
||||
vals = [val.reshape(nemo_shape) for val in vals]
|
||||
vals = [val.transpose(1, 0, 2) for val in vals]
|
||||
|
||||
vals = [val.reshape(3, local_dim) for val in vals]
|
||||
val = np.concatenate(vals, axis=-1)
|
||||
split_vals = np.split(val, split_factor, axis=-1)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
|
||||
elif "attention.qkv.weight" in key:
|
||||
hidden_dim = vals[0].shape[0]
|
||||
if local_dim is None:
|
||||
local_dim = vals[0].shape[-1] // 3
|
||||
if multi_query_mode:
|
||||
val = vals[0]
|
||||
# out_feature = local_dim + 2 * head_size; assumes local_dim equals to hidden_dim
|
||||
head_size = (val.shape[-1] - local_dim) // 2
|
||||
val = val.reshape(hidden_dim, local_dim + 2 * head_size)
|
||||
w_q, w_kv = np.split(val, [local_dim], axis=-1)
|
||||
w_q_split = np.split(w_q, split_factor, axis=-1)
|
||||
split_vals = [np.concatenate((i, w_kv), axis=-1) for i in w_q_split]
|
||||
else:
|
||||
if use_attention_nemo_shape:
|
||||
head_num = num_attention_heads // tp_size
|
||||
size_per_head = hidden_dim // num_attention_heads
|
||||
vals = [
|
||||
val.reshape(hidden_dim, head_num, 3, size_per_head)
|
||||
for val in vals
|
||||
]
|
||||
vals = [val.transpose(0, 2, 1, 3) for val in vals]
|
||||
|
||||
vals = [val.reshape(hidden_dim, 3, local_dim) for val in vals]
|
||||
cat_dim = -1
|
||||
val = np.concatenate(vals, axis=cat_dim)
|
||||
split_vals = np.split(val, split_factor, axis=cat_dim)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
if save_int8:
|
||||
base_key = key.replace(".weight", "")
|
||||
vals_i8 = generate_int8(val,
|
||||
act_range,
|
||||
is_qkv=True,
|
||||
multi_query_mode=multi_query_mode)
|
||||
write_int8(vals_i8,
|
||||
saved_dir,
|
||||
base_key,
|
||||
cat_dim,
|
||||
tp_rank,
|
||||
split_factor,
|
||||
kv_cache_only=int8_outputs == "kv_cache_only")
|
||||
|
||||
elif "attention.dense.smoother" in key or "mlp.c_proj.smoother" in key:
|
||||
split_vals = np.split(vals, split_factor, axis=0)
|
||||
save_split(split_vals, saved_dir, key, tp_rank, split_factor)
|
||||
else:
|
||||
print(f"[WARNING] {key} not handled by converter")
|
||||
134
examples/qwen/utils/utils.py
Normal file
134
examples/qwen/utils/utils.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
def make_context(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query: str,
|
||||
history: List[Tuple[str, str]] = None,
|
||||
system: str = "You are a helpful assistant.",
|
||||
max_input_length:
|
||||
int = 2048, # if you want to change this, you need to change the max_input_len in tensorrt_llm_july-release-v1/examples/qwen/build.py
|
||||
max_window_size: int = 6144,
|
||||
chat_format: str = "chatml",
|
||||
):
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
if chat_format == "chatml":
|
||||
im_start, im_end = "<|im_start|>", "<|im_end|>"
|
||||
im_start_tokens = [tokenizer.im_start_id]
|
||||
im_end_tokens = [tokenizer.im_end_id]
|
||||
nl_tokens = tokenizer.encode("\n")
|
||||
|
||||
def _tokenize_str(role, content):
|
||||
return (f"{role}\n{content}",
|
||||
tokenizer.encode(
|
||||
role,
|
||||
allowed_special=set(),
|
||||
) + nl_tokens + tokenizer.encode(
|
||||
content,
|
||||
allowed_special=set(),
|
||||
))
|
||||
|
||||
system_text, system_tokens_part = _tokenize_str("system", system)
|
||||
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
|
||||
raw_text = ""
|
||||
context_tokens = []
|
||||
|
||||
for turn_query, turn_response in reversed(history):
|
||||
query_text, query_tokens_part = _tokenize_str("user", turn_query)
|
||||
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
|
||||
|
||||
response_text, response_tokens_part = _tokenize_str(
|
||||
"assistant", turn_response)
|
||||
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
|
||||
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
|
||||
prev_chat = (
|
||||
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
|
||||
)
|
||||
|
||||
current_context_size = (len(system_tokens) +
|
||||
len(next_context_tokens) +
|
||||
len(context_tokens))
|
||||
if current_context_size < max_window_size:
|
||||
context_tokens = next_context_tokens + context_tokens
|
||||
raw_text = prev_chat + raw_text
|
||||
else:
|
||||
break
|
||||
|
||||
context_tokens = system_tokens + context_tokens
|
||||
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
|
||||
context_tokens += (nl_tokens + im_start_tokens +
|
||||
_tokenize_str("user", query)[1] + im_end_tokens +
|
||||
nl_tokens + im_start_tokens +
|
||||
tokenizer.encode("assistant") + nl_tokens)
|
||||
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
|
||||
|
||||
elif chat_format == "raw":
|
||||
raw_text = query
|
||||
context_tokens = tokenizer.encode(raw_text)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
|
||||
# truncate to max_input_length, truncate from the front
|
||||
return raw_text, context_tokens[-max_input_length:]
|
||||
|
||||
|
||||
def _decode_chatml(tokens: List[int],
|
||||
stop_words: List[str],
|
||||
eod_token_ids: List[int],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
context_length: int,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
errors: str = 'replace'):
|
||||
end_reason = f"Gen length {len(tokens)}"
|
||||
eod_token_idx = context_length
|
||||
for eod_token_idx in range(context_length, len(tokens)):
|
||||
if tokens[eod_token_idx] in eod_token_ids:
|
||||
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
|
||||
break
|
||||
|
||||
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx],
|
||||
errors=errors)[raw_text_len:]
|
||||
if verbose:
|
||||
print("\nRaw Generate w/o EOD:",
|
||||
tokenizer.decode(tokens, errors=errors)[raw_text_len:])
|
||||
print("\nRaw Generate:", trim_decode_tokens)
|
||||
print("\nEnd Reason:", end_reason)
|
||||
for stop_word in stop_words:
|
||||
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
|
||||
trim_decode_tokens = trim_decode_tokens.strip()
|
||||
if verbose:
|
||||
print("\nGenerate:", trim_decode_tokens)
|
||||
|
||||
if return_end_reason:
|
||||
return trim_decode_tokens, end_reason
|
||||
else:
|
||||
return trim_decode_tokens
|
||||
|
||||
|
||||
def get_stop_words_ids(chat_format, tokenizer):
|
||||
if chat_format == "raw":
|
||||
stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
|
||||
elif chat_format == "chatml":
|
||||
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
|
||||
return stop_words_ids
|
||||
Reference in New Issue
Block a user