add pkgs
This commit is contained in:
8
examples/chatglm/.gitignore
vendored
Normal file
8
examples/chatglm/.gitignore
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
__pycache__/
|
||||
.vscode/
|
||||
awq/
|
||||
chatglm*_6b*/
|
||||
dataset/
|
||||
glm_10b/
|
||||
output_*/
|
||||
model.cache
|
||||
166
examples/chatglm/README.md
Normal file
166
examples/chatglm/README.md
Normal file
@@ -0,0 +1,166 @@
|
||||
# ChatGLM
|
||||
|
||||
This document explains how to build the [ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b), [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b), [ChatGLM2-6B-32k](https://huggingface.co/THUDM/chatglm2-6b-32k), [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b), [ChatGLM3-6B-Base](https://huggingface.co/THUDM/chatglm3-6b-base), [ChatGLM3-6B-32k](https://huggingface.co/THUDM/chatglm3-6b-32k) models using XTRT-LLM and run on a single XPU, a single node with multiple XPUs or multiple nodes with multiple XPUs.
|
||||
|
||||
## Overview
|
||||
|
||||
The XTRT-LLM ChatGLM implementation can be found in [`xtrt_llm/models/chatglm/model.py`](../../xtrt_llm/models/chatglm/model.py).
|
||||
The XTRT-LLM ChatGLM example code is located in [`examples/chatglm`](./). There are two main files:
|
||||
|
||||
* [`build.py`](./build.py) to build the [XTRT](https://console.cloud.baidu-int.com/devops/icode/repos/baidu/xpu/xmir/tree/master) engine(s) needed to run the ChatGLM model.
|
||||
* [`run.py`](./run.py) to run the inference on an input text.
|
||||
|
||||
## Support Matrix
|
||||
|
||||
| Model Name | FP16 | FMHA | WO | AWQ | SQ | TP | PP | ST | C++ Runtime | benchmark | IFB |
|
||||
| :--------------: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---------: | :-------: | :---: |
|
||||
| chatglm_6b | Y | Y | Y | | | Y | | | | | |
|
||||
| chatglm2_6b | Y | Y | Y | | | Y | | | | | |
|
||||
| chatglm2-6b_32k | Y | Y | Y | | | Y | | | | | |
|
||||
| chatglm3_6b | Y | Y | Y | | | Y | | | | | |
|
||||
| chatglm3_6b_base | Y | Y | Y | | | Y | | | | | |
|
||||
| chatglm3_6b_32k | Y | Y | Y | | | Y | | | | | |
|
||||
| glm_10b | Y | Y | Y | | | Y | | | | | |
|
||||
|
||||
* Model Name: the name of the model, the same as the name on HuggingFace
|
||||
* FMHA: Fused MultiHead Attention (see introduction below)
|
||||
* WO: Weight Only Quantization (int8 / int4)
|
||||
* AWQ: Activation Aware Weight Quantization
|
||||
* SQ: Smooth Quantization
|
||||
* ST: Strongly Typed
|
||||
* TP: Tensor Parallel
|
||||
* PP: Pipeline Parallel
|
||||
* IFB: In-flight Batching (see introduction below)
|
||||
|
||||
## Usage
|
||||
|
||||
The next section describe how to build the engine and run the inference demo.
|
||||
|
||||
### 1. Download repo and weights from HuggingFace Transformers
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
apt-get update
|
||||
apt-get install git-lfs
|
||||
rm -rf chatglm*
|
||||
|
||||
# clone one or more models we want to build
|
||||
git clone https://huggingface.co/THUDM/chatglm-6b chatglm_6b
|
||||
git clone https://huggingface.co/THUDM/chatglm2-6b chatglm2_6b
|
||||
git clone https://huggingface.co/THUDM/chatglm2-6b-32k chatglm2_6b_32k
|
||||
git clone https://huggingface.co/THUDM/chatglm3-6b chatglm3_6b
|
||||
git clone https://huggingface.co/THUDM/chatglm3-6b-base chatglm3_6b_base
|
||||
git clone https://huggingface.co/THUDM/chatglm3-6b-32k chatglm3_6b_32k
|
||||
git clone https://huggingface.co/THUDM/glm-10b glm_10b
|
||||
```
|
||||
|
||||
### 2. Build XTRT engine(s)
|
||||
|
||||
* This ChatGLM example in XTRT-LLM builds XTRT engine(s) using HF checkpoint directly (rather than using FT checkpoints such as GPT example).
|
||||
* If no checkpoint directory is specified, XTRT-LLM will build engine(s) using dummy weights.
|
||||
* The [`build.py`](./build.py) script requires a single XPU to build the XTRT engine(s).
|
||||
* You can enable parallel builds to accelerate the engine building process if you have more than one XPU in your system (of the same model).
|
||||
* For parallel building, add the `--parallel_build` argument to the build command (this feature cannot take advantage of more than a single node).
|
||||
* The number of XTRT engines depends on the number of XPUs that will be used to run inference.
|
||||
* argument [--model_name/-m] is required, which can be one of "chatglm_6b", "chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b", "chatglm3_6b_base", "chatglm3_6b_32k" or "glm-10b" (use "_" rather than "-") for ChatGLM-6B, ChatGLM2-6B, ChatGLM2-6B-32K ChatGLM3-6B, ChatGLM3-6B-Base, ChatGLM3-6B-32K or GLM-10B model respectively.
|
||||
|
||||
#### Examples of build invocations
|
||||
|
||||
```bash
|
||||
# Build a default engine of ChatGLM3-6B on single XPU with FP16, GPT Attention plugin, Gemm plugin, RMS Normolization plugin
|
||||
python3 build.py -m chatglm3_6b
|
||||
|
||||
# Build a engine on single XPU with FMHA kernels (see introduction below), other configurations are the same as default example
|
||||
python3 build.py -m chatglm3_6b --enable_context_fmha # or --enable_context_fmha_fp32_acc
|
||||
|
||||
# Build a engine on single XPU with int8/int4 Weight-Only quantization, other configurations are the same as default example
|
||||
python3 build.py -m chatglm3_6b --use_weight_only # or --use_weight_only --weight_only_precision int4
|
||||
|
||||
# Build a engine on single XPU with int8_kv_cache and remove_input_padding, other configurations are the same as default example
|
||||
python3 build.py -m chatglm3_6b --paged_kv_cache --remove_input_padding
|
||||
|
||||
# Build a engine on two XPU, other configurations are the same as default example
|
||||
python3 build.py -m chatglm3_6b --world_size 2
|
||||
|
||||
# Build a engine of Chatglm-6B on single XPU, other configurations are the same as default example
|
||||
python3 build.py -m chatglm_6b
|
||||
|
||||
# Build a engine of Chatglm2-6B on single XPU, other configurations are the same as default example
|
||||
python3 build.py -m chatglm2_6b
|
||||
|
||||
# Build a engine of ChatGLM2-6B-32k on single XPU, other configurations are the same as default example
|
||||
python3 build.py -m chatglm2_6b-32k
|
||||
|
||||
# Build a engine of ChatGLM3-6B-Base on single XPU, other configurations are the same as default example
|
||||
python3 build.py -m chatglm3_6b_base
|
||||
|
||||
# Build a engine of ChatGLM3-6B-32k on single XPU, other configurations are the same as default example
|
||||
python3 build.py -m chatglm3_6b-32k
|
||||
|
||||
# Build a engine of GLM-10B on single XPU, other configurations are the same as default example
|
||||
python3 build.py -m glm_10b
|
||||
```
|
||||
|
||||
#### Enabled plugins
|
||||
|
||||
* Use `--use_gpt_attention_plugin <DataType>` to configure GPT Attention plugin (default as float16)
|
||||
* Use `--use_gemm_plugin <DataType>` to configure GEMM plugin (default as float16)
|
||||
* Use `--use_layernorm_plugin <DataType>` (for ChatGLM-6B and GLM-10B models) to configure layernorm normolization plugin (default as float16)
|
||||
* Use `--use_rmsnorm_plugin <DataType>` (for ChatGLM2-6B\* and ChatGLM3-6B\* models) to configure RMS normolization plugin (default as float16)
|
||||
|
||||
|
||||
#### Weight Only quantization
|
||||
|
||||
* Use `--use_weight_only` to enable INT8-Weight-Only quantization, this will siginficantly lower the latency and memory footprint.
|
||||
|
||||
* Furthermore, use `--weight_only_precision int8` or `--weight_only_precision int4` to configure the data type of the weights.
|
||||
|
||||
#### In-flight batching
|
||||
|
||||
* The engine must be built accordingly if [in-flight batching in C++ runtime](../../docs/in_flight_batching.md) will be used.
|
||||
|
||||
* Use `--use_inflight_batching` to enable In-flight Batching.
|
||||
|
||||
* Switch `--use_gpt_attention_plugin=float16`, `--paged_kv_cache`, `--remove_input_padding` will be set when using In-flight Batching.
|
||||
|
||||
* It is possible to use `--use_gpt_attention_plugin float32` In-flight Batching.
|
||||
|
||||
* The size of the block in paged KV cache can be conteoled additionally by using `--tokens_per_block=N`.
|
||||
|
||||
### 3. Run
|
||||
|
||||
#### Single node, single XPU
|
||||
|
||||
```bash
|
||||
# Run the default engine of ChatGLM3-6B on single XPU, other model name is available if built.
|
||||
python3 run.py -m chatglm3_6b
|
||||
# Run the default engine of ChatGLM3-6B on single XPU, using streaming output, other model name is available if built.
|
||||
# In this case only the first sample in the first batch is shown,
|
||||
# But actually all output of all batches are available.
|
||||
python3 run.py -m chatglm3_6b --streaming
|
||||
# Run the default engine of GLM3-10B on single XPU, other model name is available if built.
|
||||
# Token "[MASK]" or "[sMASK]" or "[gMASK]" must be included inside the prompt as the original model commanded.
|
||||
python3 run.py -m chatglm3_6b --input_text "Peking University is [MASK] than Tsinghua Univercity."
|
||||
```
|
||||
|
||||
#### Single node, multi XPU
|
||||
|
||||
```bash
|
||||
# Run the Tensor Parallel 2 engine of ChatGLM3-6B on two XPU, other model name is available if built.
|
||||
mpirun -n 2 python run.py -m chatglm3_6b
|
||||
```
|
||||
|
||||
* `--allow-run-as-root` might be needed if using `mpirun` as root.
|
||||
|
||||
#### Run comparison of performance and accuracy
|
||||
|
||||
```bash
|
||||
# Run the summarization of ChatGLM3-6B task, other model name is available if built.
|
||||
python3 ../summarize.py --test_trt_llm --tokenizer_dir chatglm3_6b --max_input_length 2048
|
||||
```
|
||||
|
||||
### 4. Note
|
||||
|
||||
* [`vllm_test/test_llm_engine.py`](../../vllm_test/test_llm_engine.py) should be run instead of run.py when `--paged_kv_cache` is set.
|
||||
* Accuray of multi-batch chatglm2/3 is not available in padding mode.
|
||||
* `--remove_input_padding` is not available in chatglm_6b.
|
||||
789
examples/chatglm/build.py
Normal file
789
examples/chatglm/build.py
Normal file
@@ -0,0 +1,789 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# isort: off
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import tvm.tensorrt as trt
|
||||
# isort: on
|
||||
from visualize import to_onnx
|
||||
from weight import get_scaling_factors, load_from_hf
|
||||
|
||||
import xtrt_llm as tensorrt_llm
|
||||
from xtrt_llm._utils import str_dtype_to_trt
|
||||
from xtrt_llm.builder import Builder
|
||||
from xtrt_llm.logger import logger
|
||||
from xtrt_llm.mapping import Mapping
|
||||
from xtrt_llm.models import ChatGLMHeadModel, quantize_model
|
||||
from xtrt_llm.network import net_guard
|
||||
from xtrt_llm.plugin.plugin import ContextFMHAType
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
|
||||
def get_engine_name(model, dtype, tp_size, pp_size, rank):
|
||||
if pp_size == 1:
|
||||
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
|
||||
return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
|
||||
pp_size, rank)
|
||||
|
||||
|
||||
def find_engines(dir: Path,
|
||||
model_name: str = "*",
|
||||
dtype: str = "*",
|
||||
tp_size: str = "*",
|
||||
rank: str = "*") -> List[Path]:
|
||||
template = f"{model_name}_{dtype}_tp{tp_size}_rank{rank}.engine"
|
||||
return [f"{str(dir)}/{template}"]
|
||||
return list(dir.glob(template))
|
||||
|
||||
|
||||
def serialize_engine(engine, path):
|
||||
logger.info(f'Serializing engine to {path}...')
|
||||
tik = time.time()
|
||||
'''
|
||||
with open(path, 'wb') as f:
|
||||
f.write(bytearray(engine))
|
||||
'''
|
||||
engine.serialize(str(path))
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Engine serialized. Total time: {t}')
|
||||
|
||||
|
||||
def truncate_input_output_len(
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
max_seq_length_from_config,
|
||||
is_fixed_max_position_length=False,
|
||||
):
|
||||
max_seq_length = max_seq_length_from_config
|
||||
if max_input_len >= max_seq_length_from_config:
|
||||
print("Truncate max_input_len as %d" % (max_seq_length_from_config - 1))
|
||||
max_input_len = max_seq_length_from_config - 1
|
||||
max_output_len = 1
|
||||
elif max_input_len + max_output_len > max_seq_length_from_config:
|
||||
print("Truncate max_output_len as %d" %
|
||||
(max_seq_length_from_config - max_input_len))
|
||||
max_output_len = max_seq_length_from_config - max_input_len
|
||||
elif not is_fixed_max_position_length:
|
||||
max_seq_length = max_input_len + max_output_len
|
||||
return max_input_len, max_output_len, max_seq_length
|
||||
|
||||
|
||||
def parse_arguments(args):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--model_name',
|
||||
'-m',
|
||||
type=str,
|
||||
required=True,
|
||||
choices=[
|
||||
"chatglm_6b", "chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b",
|
||||
"chatglm3_6b_base", "chatglm3_6b_32k", "glm_10b"
|
||||
],
|
||||
help=
|
||||
'the name of the model, use "_" rather than "-" to connect the name parts'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--world_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='world size, only support tensor parallelism now',
|
||||
)
|
||||
parser.add_argument('--tp_size', type=int, default=1)
|
||||
parser.add_argument('--pp_size', type=int, default=1)
|
||||
parser.add_argument('--model_dir', type=Path, default=None)
|
||||
parser.add_argument('--quant_ckpt_path', type=str, default="awq/")
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float32', 'float16', 'bfloat16'],
|
||||
)
|
||||
parser.add_argument(
|
||||
'--logits_dtype',
|
||||
type=str,
|
||||
default='float32',
|
||||
choices=['float16', 'float32'],
|
||||
)
|
||||
parser.add_argument(
|
||||
'--timing_cache',
|
||||
type=str,
|
||||
default='model.cache',
|
||||
help=
|
||||
'The path of to read timing cache from, will be ignored if the file does not exist'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--log_level',
|
||||
type=str,
|
||||
default='info',
|
||||
choices=['verbose', 'info', 'warning', 'error', 'internal_error'],
|
||||
)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--max_input_len', type=int, default=1024)
|
||||
parser.add_argument('--max_output_len', type=int, default=1024)
|
||||
parser.add_argument('--max_beam_width', type=int, default=1)
|
||||
parser.add_argument(
|
||||
'--use_gpt_attention_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
default='float16',
|
||||
choices=['float32', 'float16', 'bfloat16', False],
|
||||
help=
|
||||
"Activates attention plugin. You can specify the plugin dtype or leave blank to use the model dtype."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_gemm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float32', 'float16', 'bfloat16', False],
|
||||
help=
|
||||
"Activates GEMM plugin. You can specify the plugin dtype or leave blank to use the model dtype."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_layernorm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float32', 'float16', 'bfloat16', False],
|
||||
help=
|
||||
"Activates layernorm plugin for ChatGLM-6B / GLM-10B models. You can specify the plugin dtype or leave blank to use the model dtype."
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_rmsnorm_plugin',
|
||||
nargs='?',
|
||||
const='float16',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float32', 'float16', 'bfloat16', False],
|
||||
help=
|
||||
"Activates rmsnorm plugin for ChatGLM2-6B* / ChatGLM3-6B* models. You can specify the plugin dtype or leave blank to use the model dtype."
|
||||
)
|
||||
parser.add_argument('--gather_all_token_logits',
|
||||
action='store_true',
|
||||
default=False)
|
||||
parser.add_argument('--parallel_build', default=False, action='store_true')
|
||||
parser.add_argument(
|
||||
'--enable_context_fmha',
|
||||
default=False,
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--enable_context_fmha_fp32_acc',
|
||||
default=False,
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--multi_block_mode',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help=
|
||||
'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \
|
||||
It is beneifical when batchxnum_heads cannot fully utilize XPU.'
|
||||
)
|
||||
parser.add_argument('--visualize', default=False, action='store_true')
|
||||
parser.add_argument(
|
||||
'--enable_debug_output',
|
||||
default=False,
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument('--gpus_per_node', type=int, default=8)
|
||||
parser.add_argument('--builder_opt', type=int, default=None)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=Path,
|
||||
default=None,
|
||||
help=
|
||||
'The path to save the serialized engine files, timing cache file and model configs'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--strongly_typed',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--remove_input_padding',
|
||||
default=False,
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--paged_kv_cache',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=
|
||||
'By default we use contiguous KV cache. By setting this flag you enable paged KV cache'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_inflight_batching',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activates inflight batching mode of gptAttentionPlugin.",
|
||||
)
|
||||
|
||||
# Arguments related to the quantization of the model.
|
||||
parser.add_argument(
|
||||
'--use_smooth_quant',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'Use the SmoothQuant method to quantize activations and weights for the various GEMMs.'
|
||||
'See --per_channel and --per_token for finer-grained quantization options.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_weight_only',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
||||
'See --weight_only_precision to set the precision',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--weight_only_precision',
|
||||
const='int8',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default='int8',
|
||||
choices=['int8', 'int4', 'int4_awq'],
|
||||
help=
|
||||
'Define the precision for the weights when using weight-only quantization.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_channel',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor for the GEMM\'s result. '
|
||||
'per_channel instead uses a different static scaling factor for each channel. '
|
||||
'The latter is usually more accurate, but a little slower.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_token',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale activations in the int8 range. '
|
||||
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
||||
'The latter is usually more accurate, but a little slower.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_group',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale weights in the int4 range. '
|
||||
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
||||
'The flag is built for GPTQ/AWQ quantization.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--group_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Group size used in GPTQ/AWQ quantization.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--int8_kv_cache',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--random_seed',
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
'Seed to use when initializing the random number generator for torch.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--tokens_per_block',
|
||||
type=int,
|
||||
default=64,
|
||||
help='Number of tokens per block in paged KV cache',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--enable_fp8',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='Use FP8 Linear layer for Attention QKV/Dense and MLP.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--fp8_kv_cache',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use dtype for KV cache. fp8_kv_cache chooses fp8 quantization for KV'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max_num_tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Define the max number of tokens supported by the engine',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_custom_all_reduce',
|
||||
action='store_true',
|
||||
help=
|
||||
'Activates latency-optimized algorithm for all-reduce instead of NCCL.',
|
||||
)
|
||||
args = parser.parse_args(args)
|
||||
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
plugins_args = [
|
||||
'use_gpt_attention_plugin',
|
||||
'use_gemm_plugin',
|
||||
'use_layernorm_plugin',
|
||||
'use_rmsnorm_plugin',
|
||||
]
|
||||
for plugin_arg in plugins_args:
|
||||
if getattr(args, plugin_arg) is None:
|
||||
logger.info(
|
||||
f"{plugin_arg} set, without specifying a value. Using {args.dtype} automatically."
|
||||
)
|
||||
setattr(args, plugin_arg, args.dtype)
|
||||
|
||||
assert args.world_size == args.tp_size * args.pp_size # only TP is supported now
|
||||
|
||||
if args.model_dir is None:
|
||||
args.model_dir = Path(args.model_name)
|
||||
if args.output_dir is None:
|
||||
args.output_dir = Path("output_" + args.model_name)
|
||||
with open(args.model_dir / "config.json", "r") as f:
|
||||
js = json.loads(f.read())
|
||||
|
||||
if args.model_name in ["chatglm_6b", "glm_10b"]:
|
||||
assert args.max_input_len < js["max_sequence_length"]
|
||||
|
||||
if args.model_name in ["chatglm_6b"]:
|
||||
args.apply_query_key_layer_scaling = False
|
||||
args.apply_residual_connection_post_layernorm = False
|
||||
args.ffn_hidden_size = js["inner_hidden_size"]
|
||||
args.hidden_act = 'gelu'
|
||||
args.hidden_size = js["hidden_size"]
|
||||
args.linear_bias = True
|
||||
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
js["max_sequence_length"],
|
||||
)
|
||||
args.multi_query_mode = False
|
||||
args.norm_epsilon = js["layernorm_epsilon"]
|
||||
args.num_heads = js["num_attention_heads"]
|
||||
args.num_kv_heads = js["num_attention_heads"]
|
||||
args.num_layers = js["num_layers"]
|
||||
args.qkv_bias = True
|
||||
args.rmsnorm = False
|
||||
args.rotary_embedding_scaling = 1.0
|
||||
args.use_cache = js["use_cache"]
|
||||
args.vocab_size = js["vocab_size"]
|
||||
elif args.model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
args.apply_query_key_layer_scaling = False
|
||||
args.apply_residual_connection_post_layernorm = js[
|
||||
"apply_residual_connection_post_layernorm"]
|
||||
args.ffn_hidden_size = js["ffn_hidden_size"]
|
||||
args.hidden_act = 'swiglu'
|
||||
args.hidden_size = js["hidden_size"]
|
||||
args.linear_bias = js["add_bias_linear"]
|
||||
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
js["seq_length"],
|
||||
)
|
||||
args.multi_query_mode = js["multi_query_attention"]
|
||||
args.norm_epsilon = js["layernorm_epsilon"]
|
||||
args.num_heads = js["num_attention_heads"]
|
||||
args.num_kv_heads = js["multi_query_group_num"]
|
||||
args.num_layers = js["num_layers"]
|
||||
args.qkv_bias = js["add_qkv_bias"]
|
||||
args.rmsnorm = js["rmsnorm"]
|
||||
if args.model_name in ["chatglm2_6b_32k", "chatglm3_6b_32k"]:
|
||||
args.rotary_embedding_scaling = js["rope_ratio"]
|
||||
else:
|
||||
args.rotary_embedding_scaling = 1.0
|
||||
args.use_cache = js["use_cache"]
|
||||
args.vocab_size = js["padded_vocab_size"]
|
||||
elif args.model_name in ["glm_10b"]:
|
||||
args.apply_query_key_layer_scaling = False
|
||||
args.apply_residual_connection_post_layernorm = False
|
||||
args.ffn_hidden_size = 4 * js["hidden_size"]
|
||||
args.hidden_act = 'gelu'
|
||||
args.hidden_size = js["hidden_size"]
|
||||
args.linear_bias = True
|
||||
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
js["max_sequence_length"],
|
||||
True,
|
||||
)
|
||||
args.multi_query_mode = False
|
||||
args.norm_epsilon = 1.0e-5
|
||||
args.num_heads = js["num_attention_heads"]
|
||||
args.num_kv_heads = js["num_attention_heads"]
|
||||
args.num_layers = js["num_layers"]
|
||||
args.qkv_bias = True
|
||||
args.rmsnorm = False
|
||||
args.rotary_embedding_scaling = 1.0
|
||||
args.use_cache = True
|
||||
args.vocab_size = js["vocab_size"]
|
||||
|
||||
if args.use_inflight_batching:
|
||||
if not args.use_gpt_attention_plugin:
|
||||
args.use_gpt_attention_plugin = 'float16'
|
||||
logger.info(
|
||||
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
|
||||
)
|
||||
if not args.remove_input_padding:
|
||||
args.remove_input_padding = True
|
||||
logger.info(
|
||||
"Using remove input padding for inflight batching mode.")
|
||||
if not args.paged_kv_cache:
|
||||
args.paged_kv_cache = True
|
||||
logger.info("Using paged KV cache for inflight batching mode.")
|
||||
|
||||
assert not (
|
||||
args.use_smooth_quant and args.use_weight_only
|
||||
), "You cannot enable both SmoothQuant and INT8 weight-only together."
|
||||
|
||||
if args.use_smooth_quant:
|
||||
args.quant_mode = QuantMode.use_smooth_quant(args.per_token,
|
||||
args.per_channel)
|
||||
elif args.use_weight_only:
|
||||
args.quant_mode = QuantMode.use_weight_only(
|
||||
args.weight_only_precision == 'int4')
|
||||
else:
|
||||
args.quant_mode = QuantMode(0)
|
||||
|
||||
if args.int8_kv_cache:
|
||||
args.quant_mode = args.quant_mode.set_int8_kv_cache()
|
||||
|
||||
elif args.fp8_kv_cache:
|
||||
args.quant_mode = args.quant_mode.set_fp8_kv_cache()
|
||||
if args.enable_fp8:
|
||||
args.quant_mode = args.quant_mode.set_fp8_qdq()
|
||||
|
||||
if args.max_num_tokens is not None:
|
||||
assert args.enable_context_fmha
|
||||
|
||||
logger.info(' Build Arguments '.center(100, '='))
|
||||
for k, v in vars(args).items():
|
||||
logger.info(f' - {k.ljust(30, ".")}: {v}')
|
||||
logger.info('=' * 100)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def build_rank_engine(
|
||||
builder: Builder,
|
||||
builder_config: tensorrt_llm.builder.BuilderConfig,
|
||||
engine_name: str,
|
||||
rank: int,
|
||||
args: argparse.Namespace,
|
||||
) -> trt.ICudaEngine:
|
||||
'''
|
||||
@brief: Build the engine on the given rank.
|
||||
@param rank: The rank to build the engine.
|
||||
@param args: The cmd line arguments.
|
||||
@return: The built engine.
|
||||
'''
|
||||
# Initialize Module
|
||||
args.mapping = Mapping(
|
||||
world_size=args.world_size,
|
||||
rank=rank,
|
||||
tp_size=args.tp_size,
|
||||
)
|
||||
assert args.num_layers % args.pp_size == 0, \
|
||||
f"num_layers {args.n_layer} must be a multiple of pipeline "\
|
||||
f"parallelism size {args.pp_size}"
|
||||
trtllm_model = ChatGLMHeadModel(
|
||||
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
|
||||
apply_residual_connection_post_layernorm=args.
|
||||
apply_residual_connection_post_layernorm,
|
||||
dtype=args.dtype,
|
||||
enable_debug_output=args.enable_debug_output,
|
||||
ffn_hidden_size=args.ffn_hidden_size,
|
||||
hidden_act=args.hidden_act,
|
||||
hidden_size=args.hidden_size,
|
||||
linear_bias=args.linear_bias,
|
||||
logits_dtype=args.logits_dtype,
|
||||
mapping=args.mapping,
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
max_seq_length=args.max_seq_length,
|
||||
model_name=args.model_name,
|
||||
norm_epsilon=args.norm_epsilon,
|
||||
num_heads=args.num_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
num_layers=args.num_layers,
|
||||
qkv_bias=args.qkv_bias,
|
||||
quant_mode=args.quant_mode,
|
||||
rmsnorm=args.rmsnorm,
|
||||
rotary_embedding_scaling=args.rotary_embedding_scaling,
|
||||
tokens_per_block=args.tokens_per_block,
|
||||
use_cache=args.use_cache,
|
||||
vocab_size=args.vocab_size,
|
||||
)
|
||||
'''
|
||||
if args.use_smooth_quant or args.use_weight_only:
|
||||
'''
|
||||
if args.use_smooth_quant:
|
||||
trtllm_model = quantize_model(trtllm_model, args.quant_mode)
|
||||
elif args.enable_fp8 or args.fp8_kv_cache:
|
||||
logger.info(f'Loading scaling factors from '
|
||||
f'{args.quantized_fp8_model_path}')
|
||||
quant_scales = get_scaling_factors(args.quantized_fp8_model_path,
|
||||
num_layers=args.n_layer,
|
||||
quant_mode=args.quant_mode)
|
||||
trtllm_model = quantize_model(trtllm_model,
|
||||
quant_mode=args.quant_mode,
|
||||
quant_scales=quant_scales)
|
||||
elif args.use_weight_only:
|
||||
builder_config.trt_builder_config.use_weight_only = args.weight_only_precision
|
||||
|
||||
trtllm_model = load_from_hf(
|
||||
trtllm_model,
|
||||
args.model_dir,
|
||||
mapping=args.mapping,
|
||||
dtype=args.dtype,
|
||||
model_name=args.model_name,
|
||||
)
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
if args.use_gpt_attention_plugin:
|
||||
network.plugin_config.set_gpt_attention_plugin(
|
||||
dtype=args.use_gpt_attention_plugin)
|
||||
if args.use_gemm_plugin:
|
||||
if not args.enable_fp8:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
|
||||
else:
|
||||
logger.info(
|
||||
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
|
||||
if args.use_rmsnorm_plugin:
|
||||
network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin)
|
||||
|
||||
# Quantization plugins.
|
||||
if args.use_smooth_quant:
|
||||
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_quantize_tensor_plugin()
|
||||
network.plugin_config.set_quantize_per_token_plugin()
|
||||
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
|
||||
if args.enable_context_fmha:
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
if args.enable_context_fmha_fp32_acc:
|
||||
network.plugin_config.set_context_fmha(
|
||||
ContextFMHAType.enabled_with_fp32_acc)
|
||||
if args.multi_block_mode:
|
||||
network.plugin_config.enable_mmha_multi_block_mode()
|
||||
if args.use_weight_only:
|
||||
if args.per_group:
|
||||
network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
else:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
if args.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(args.dtype,
|
||||
args.use_custom_all_reduce)
|
||||
if args.remove_input_padding:
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
if args.paged_kv_cache:
|
||||
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
|
||||
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(trtllm_model.named_parameters())
|
||||
|
||||
# Forward
|
||||
inputs = trtllm_model.prepare_inputs(
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_input_len=args.max_input_len,
|
||||
max_new_tokens=args.max_output_len,
|
||||
use_cache=True,
|
||||
max_beam_width=args.max_beam_width,
|
||||
)
|
||||
trtllm_model(*inputs)
|
||||
if args.enable_debug_output:
|
||||
# mark intermediate nodes' outputs
|
||||
for k, v in trtllm_model.named_network_outputs():
|
||||
v = v.trt_tensor
|
||||
v.name = k
|
||||
network.trt_network.mark_output(v)
|
||||
v.dtype = str_dtype_to_trt(args.dtype)
|
||||
if args.visualize:
|
||||
model_path = args.output_dir / 'test.onnx'
|
||||
to_onnx(network.trt_network, model_path)
|
||||
'''
|
||||
tensorrt_llm.graph_rewriting.optimize(network)
|
||||
'''
|
||||
|
||||
# Network -> Engine
|
||||
engine = None
|
||||
engine = builder.build_engine(network, builder_config)
|
||||
if rank == 0:
|
||||
config_path = args.output_dir / 'config.json'
|
||||
builder.save_config(builder_config, config_path)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def build(rank, args):
|
||||
torch.cuda.set_device(rank % args.gpus_per_node)
|
||||
logger.set_level(args.log_level)
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
timing_cache_file = args.output_dir / "model.cache"
|
||||
timing_cache = timing_cache_file
|
||||
|
||||
builder = Builder()
|
||||
|
||||
for cur_rank in range(args.world_size):
|
||||
# skip other ranks if parallel_build is enabled
|
||||
if args.parallel_build and cur_rank != rank:
|
||||
continue
|
||||
# NOTE: when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
|
||||
int8_trt_flag = args.quant_mode.has_act_or_weight_quant() or (
|
||||
not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache())
|
||||
builder_config = builder.create_builder_config(
|
||||
precision=args.dtype,
|
||||
timing_cache=timing_cache,
|
||||
tensor_parallel=args.tp_size,
|
||||
pipeline_parallel=args.pp_size,
|
||||
int8=int8_trt_flag,
|
||||
fp8=args.enable_fp8,
|
||||
strongly_typed=args.strongly_typed,
|
||||
opt_level=args.builder_opt,
|
||||
hardware_compatibility=None,
|
||||
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
|
||||
gather_all_token_logits=args.gather_all_token_logits,
|
||||
hidden_act=args.hidden_act,
|
||||
hidden_size=args.hidden_size,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_beam_width=args.max_beam_width,
|
||||
max_input_len=args.max_input_len,
|
||||
max_num_tokens=args.max_output_len + args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
max_position_embeddings=args.max_seq_length,
|
||||
multi_query_mode=args.multi_query_mode,
|
||||
name=args.model_name,
|
||||
num_heads=args.num_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
inter_size = args.ffn_hidden_size,
|
||||
num_layers=args.num_layers,
|
||||
paged_kv_cache=args.paged_kv_cache,
|
||||
parallel_build=args.parallel_build,
|
||||
quant_mode=args.quant_mode,
|
||||
remove_input_padding=args.remove_input_padding,
|
||||
vocab_size=args.vocab_size,
|
||||
fusion_pattern_list=["remove_dup_mask"],
|
||||
)
|
||||
guard = tensorrt_llm.fusion_patterns.FuseonPatternGuard()
|
||||
print(guard)
|
||||
|
||||
engine_name = get_engine_name(
|
||||
args.model_name,
|
||||
args.dtype,
|
||||
args.world_size,
|
||||
args.pp_size,
|
||||
cur_rank,
|
||||
)
|
||||
engine = build_rank_engine(
|
||||
builder,
|
||||
builder_config,
|
||||
engine_name,
|
||||
cur_rank,
|
||||
args,
|
||||
)
|
||||
assert engine is not None, f'Failed to build engine for rank {cur_rank}'
|
||||
'''
|
||||
local_num_kv_heads = (args.num_kv_heads + args.world_size -
|
||||
1) // args.world_size
|
||||
kv_dtype = str_dtype_to_trt(args.dtype)
|
||||
if args.quant_mode.has_int8_kv_cache():
|
||||
kv_dtype = str_dtype_to_trt('int8')
|
||||
elif args.quant_mode.has_fp8_kv_cache():
|
||||
kv_dtype = str_dtype_to_trt('fp8')
|
||||
check_gpt_mem_usage(
|
||||
engine=engine,
|
||||
kv_dtype=kv_dtype,
|
||||
use_gpt_attention_plugin=args.use_gpt_attention_plugin,
|
||||
paged_kv_cache=args.paged_kv_cache,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_beam_width=args.max_beam_width,
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
local_num_kv_heads=local_num_kv_heads,
|
||||
head_size=args.hidden_size // args.num_heads,
|
||||
num_layers=args.num_layers)
|
||||
'''
|
||||
|
||||
if cur_rank == 0:
|
||||
# Use in-memory timing cache for multiple builder passes.
|
||||
if not args.parallel_build:
|
||||
timing_cache = builder_config.trt_builder_config.get_timing_cache(
|
||||
)
|
||||
|
||||
serialize_engine(engine, args.output_dir / engine_name)
|
||||
del engine
|
||||
'''
|
||||
if rank == 0:
|
||||
ok = builder.save_timing_cache(builder_config, timing_cache_file)
|
||||
assert ok, "Failed to save timing cache."
|
||||
'''
|
||||
|
||||
|
||||
def run_build(args=None):
|
||||
args = parse_arguments(args)
|
||||
|
||||
if args.random_seed is not None:
|
||||
torch.manual_seed(args.random_seed)
|
||||
|
||||
logger.set_level(args.log_level)
|
||||
tik = time.time()
|
||||
if args.parallel_build and args.world_size > 1 and \
|
||||
torch.cuda.device_count() >= args.world_size:
|
||||
logger.warning(
|
||||
f'Parallelly build XTRT engines. Please make sure that all of the {args.world_size} XPUs are totally free.'
|
||||
)
|
||||
mp.spawn(build, nprocs=args.world_size, args=(args, ))
|
||||
else:
|
||||
args.parallel_build = False
|
||||
logger.info('Serially build XTRT engines.')
|
||||
build(0, args)
|
||||
|
||||
tok = time.time()
|
||||
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
|
||||
logger.info(f'Total time of building all {args.world_size} engines: {t}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_build()
|
||||
40
examples/chatglm/process.py
Normal file
40
examples/chatglm/process.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import re
|
||||
|
||||
|
||||
def process_response_chatglm_6b(responseList):
|
||||
# from chatglm-6b/modeling_chatflm.py
|
||||
for i, response in enumerate(responseList):
|
||||
response = response.strip()
|
||||
punkts = [
|
||||
[",", ","],
|
||||
["!", "!"],
|
||||
[":", ":"],
|
||||
[";", ";"],
|
||||
["\?", "?"],
|
||||
]
|
||||
for item in punkts:
|
||||
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0],
|
||||
r"\1%s" % item[1], response)
|
||||
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0],
|
||||
r"%s\1" % item[1], response)
|
||||
|
||||
responseList[i] = response
|
||||
return responseList
|
||||
|
||||
|
||||
def process_response(responseList):
|
||||
return responseList
|
||||
157
examples/chatglm/quantize.py
Normal file
157
examples/chatglm/quantize.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from examples/quantization/hf_ptq.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tensorrt_llm._utils import str_dtype_to_torch
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.models.quantized.ammo import quantize_and_export
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def get_calib_dataloader(data="cnn_dailymail",
|
||||
tokenizer=None,
|
||||
batch_size=1,
|
||||
calib_size=512,
|
||||
block_size=512,
|
||||
cache_dir=None):
|
||||
print("Loading calibration dataset")
|
||||
if data == "pileval":
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||
split="train",
|
||||
cache_dir=cache_dir)
|
||||
dataset = dataset["text"][:calib_size]
|
||||
elif data == "cnn_dailymail":
|
||||
dataset = load_dataset("cnn_dailymail",
|
||||
name="3.0.0",
|
||||
split="train",
|
||||
cache_dir=cache_dir)
|
||||
dataset = dataset["article"][:calib_size]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
batch_encoded = tokenizer.batch_encode_plus(dataset,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
max_length=block_size)
|
||||
batch_encoded = batch_encoded["input_ids"]
|
||||
batch_encoded = batch_encoded.cuda()
|
||||
|
||||
calib_dataloader = DataLoader(batch_encoded,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
return calib_dataloader
|
||||
|
||||
|
||||
def get_tokenizer(ckpt_path, **kwargs):
|
||||
logger.info(f"Loading tokenizer from {ckpt_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt_path,
|
||||
trust_remote_code=True,
|
||||
padding_side="left",
|
||||
**kwargs)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(ckpt_path, dtype="float16", cache_dir=None):
|
||||
logger.info(f"Loading model from {ckpt_path}")
|
||||
torch_dtype = str_dtype_to_torch(dtype)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
ckpt_path,
|
||||
device_map="auto",
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
model.eval()
|
||||
model = model.to(memory_format=torch.channels_last)
|
||||
return model
|
||||
|
||||
|
||||
def parse_arguments(args):
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
'--model_name',
|
||||
'-m',
|
||||
type=str,
|
||||
required=True,
|
||||
choices=[
|
||||
"chatglm_6b", "chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b",
|
||||
"chatglm3_6b_base", "chatglm3_6b_32k", "glm_10b"
|
||||
],
|
||||
help=
|
||||
'the name of the model, use "_" rather than "-" to connect the name parts'
|
||||
)
|
||||
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||
parser.add_argument(
|
||||
"--qformat",
|
||||
type=str,
|
||||
choices=['fp8', 'int4_awq'],
|
||||
default='int4_awq',
|
||||
help='Quantization format. Currently only fp8 is supported. '
|
||||
'For int8 smoothquant, use smoothquant.py instead. ')
|
||||
parser.add_argument("--calib_size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of samples for calibration.")
|
||||
parser.add_argument('--model_dir', type=str, default=None)
|
||||
parser.add_argument("--export_path", default="awq")
|
||||
parser.add_argument("--cache_dir",
|
||||
type=str,
|
||||
default="dataset/",
|
||||
help="Directory of dataset cache.")
|
||||
parser.add_argument('--seed', type=int, default=None, help='Random seed')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main(args=None):
|
||||
if not torch.cuda.is_available():
|
||||
raise EnvironmentError("GPU is required for inference.")
|
||||
|
||||
args = parse_arguments(args)
|
||||
|
||||
if args.model_dir is None:
|
||||
args.model_dir = args.model_name
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
tokenizer = get_tokenizer(args.model_dir, cache_dir=args.cache_dir)
|
||||
model = get_model(args.model_dir, args.dtype, cache_dir=args.cache_dir)
|
||||
|
||||
calib_dataloader = get_calib_dataloader(tokenizer=tokenizer,
|
||||
calib_size=args.calib_size,
|
||||
cache_dir=args.cache_dir)
|
||||
model = quantize_and_export(model,
|
||||
qformat=args.qformat,
|
||||
calib_dataloader=calib_dataloader,
|
||||
export_path=args.export_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
examples/chatglm/requirements.txt
Normal file
5
examples/chatglm/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
datasets~=2.14.5
|
||||
evaluate~=0.4.1
|
||||
protobuf
|
||||
rouge_score~=0.1.2
|
||||
sentencepiece
|
||||
371
examples/chatglm/run.py
Normal file
371
examples/chatglm/run.py
Normal file
@@ -0,0 +1,371 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
import xtrt_llm
|
||||
import xtrt_llm as tensorrt_llm
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
from xtrt_llm.runtime import (ChatGLMGenerationSession, GenerationSession,
|
||||
ModelConfig, SamplingConfig)
|
||||
|
||||
from build import find_engines # isort:skip
|
||||
|
||||
|
||||
def parse_arguments(args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--model_name',
|
||||
'-m',
|
||||
type=str,
|
||||
required=True,
|
||||
choices=[
|
||||
"chatglm_6b", "chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b",
|
||||
"chatglm3_6b_base", "chatglm3_6b_32k", "glm_10b"
|
||||
],
|
||||
help=
|
||||
'the name of the model, use "_" rather than "-" to connect the name parts'
|
||||
)
|
||||
parser.add_argument('--max_output_len', type=int, default=1024)
|
||||
parser.add_argument('--log_level', type=str, default='error')
|
||||
parser.add_argument('--engine_dir', type=str, default=None)
|
||||
parser.add_argument('--beam_width', type=int, default=1)
|
||||
parser.add_argument('--streaming', default=False, action='store_true')
|
||||
parser.add_argument(
|
||||
'--input_text',
|
||||
type=str,
|
||||
nargs='*',
|
||||
default=[
|
||||
"What's new between ChatGLM3-6B and ChatGLM2-6B?",
|
||||
"Could you introduce NVIDIA Corporation for me?",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
'--input_tokens',
|
||||
type=str,
|
||||
help=
|
||||
'CSV or Numpy file containing tokenized input. Alternative to text input.',
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--tokenizer_dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Directory containing the tokenizer model.',
|
||||
)
|
||||
parser.add_argument('--temperature', type=float, default=1.0)
|
||||
parser.add_argument('--top_k', type=int, default=1)
|
||||
parser.add_argument('--top_p', type=float, default=0.0)
|
||||
parser.add_argument('--random_seed', type=int, default=1)
|
||||
parser.add_argument(
|
||||
'--performance_test_scale',
|
||||
type=str,
|
||||
help=
|
||||
"Scale for performance test. e.g., 8x1024x64 (batch_size, input_text_length, max_output_length)",
|
||||
default="")
|
||||
|
||||
args = parser.parse_args(args)
|
||||
|
||||
if args.engine_dir is None:
|
||||
args.engine_dir = Path("output_" + args.model_name)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
tensorrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
config_path = Path(args.engine_dir) / 'config.json'
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
dtype = config['builder_config']['precision']
|
||||
max_batch_size = config['builder_config']['max_batch_size']
|
||||
max_input_len = config['builder_config']['max_input_len']
|
||||
max_output_len = config['builder_config']['max_output_len']
|
||||
max_beam_width = config['builder_config']['max_beam_width']
|
||||
remove_input_padding = config['builder_config']['remove_input_padding']
|
||||
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
|
||||
tp_size = config['builder_config']['tensor_parallel']
|
||||
pp_size = config['builder_config']['pipeline_parallel']
|
||||
world_size = tp_size * pp_size
|
||||
assert world_size == tensorrt_llm.mpi_world_size(), \
|
||||
f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
|
||||
|
||||
if args.model_name not in ("chatglm_6b", "glm_10b") and len(
|
||||
args.input_text) > 1 and not remove_input_padding:
|
||||
print(
|
||||
"Accuracy of multi-batch chatglm2/3 is not available in padding mode!"
|
||||
)
|
||||
args.input_text = args.input_text[:1]
|
||||
|
||||
if args.max_output_len > max_output_len:
|
||||
print("Truncate max_output_len as %d" % max_output_len)
|
||||
max_output_len = min(max_output_len, args.max_output_len)
|
||||
if args.beam_width > max_beam_width:
|
||||
print("Truncate beam_width as %d" % max_beam_width)
|
||||
beam_width = min(max_beam_width, args.beam_width)
|
||||
|
||||
runtime_rank = tensorrt_llm.mpi_rank()
|
||||
runtime_mapping = tensorrt_llm.Mapping(
|
||||
world_size,
|
||||
runtime_rank,
|
||||
tp_size=world_size,
|
||||
)
|
||||
torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
|
||||
if world_size > 1:
|
||||
import os
|
||||
os.environ["XCCL_GROUP_ID"] = str(runtime_rank // world_size)
|
||||
os.environ["XCCL_NRANKS"] = str(world_size)
|
||||
os.environ["XCCL_CUR_RANK"] = str(runtime_rank % world_size)
|
||||
os.environ["XCCL_DEVICE_ID"] = str(runtime_rank)
|
||||
os.environ["MP_RUN"] = str(1)
|
||||
serialize_path = find_engines(
|
||||
Path(args.engine_dir),
|
||||
model_name=args.model_name,
|
||||
dtype=dtype,
|
||||
tp_size=world_size,
|
||||
rank=runtime_rank,
|
||||
)[0]
|
||||
|
||||
if args.tokenizer_dir is None:
|
||||
args.tokenizer_dir = args.model_name
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_dir, trust_remote_code=True)
|
||||
end_id = tokenizer.eos_token_id
|
||||
pad_id = tokenizer.pad_token_id
|
||||
if args.model_name in ["glm_10b"]:
|
||||
sop_id = tokenizer.sop_token_id
|
||||
eop_id = tokenizer.eop_token_id
|
||||
input_ids = None
|
||||
input_text = None
|
||||
if args.input_tokens is None:
|
||||
input_text = args.input_text
|
||||
batch_size = len(input_text)
|
||||
if batch_size > max_batch_size:
|
||||
print("Truncate batch_size as %d" % max_batch_size)
|
||||
batch_size = max_batch_size
|
||||
input_text = input_text[:max_batch_size]
|
||||
tokenized = tokenizer(input_text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
return_length=True)
|
||||
input_ids = tokenized['input_ids'].int()
|
||||
input_lengths = tokenized['length'].int()
|
||||
max_input_len_real = torch.max(input_lengths)
|
||||
if max_input_len_real > max_input_len:
|
||||
print("Truncate input_length as %d" % max_input_len)
|
||||
input_ids = input_ids[:, :max_input_len]
|
||||
input_lengths = torch.where(input_lengths > max_input_len,
|
||||
max_input_len, input_lengths)
|
||||
else:
|
||||
max_input_len = max_input_len_real
|
||||
if args.model_name in ["glm_10b"]:
|
||||
input_ids = torch.cat(
|
||||
(input_ids, input_ids.new_full((batch_size, 1), sop_id)),
|
||||
dim=-1,
|
||||
)
|
||||
input_lengths += 1
|
||||
max_input_len_real += 1
|
||||
|
||||
else:
|
||||
input_ids = []
|
||||
with open(args.input_tokens) as f_in:
|
||||
for line in f_in:
|
||||
for e in line.strip().split(','):
|
||||
input_ids.append(int(e))
|
||||
input_text = "<ids from file>"
|
||||
input_ids = torch.tensor(input_ids,
|
||||
dtype=torch.int32).cuda().unsqueeze(0)
|
||||
|
||||
if remove_input_padding:
|
||||
input_ids_no_padding = torch.zeros(1,
|
||||
torch.sum(input_lengths),
|
||||
dtype=torch.int32)
|
||||
lengths_acc = torch.cumsum(
|
||||
torch.cat([torch.IntTensor([0]), input_lengths]),
|
||||
dim=0,
|
||||
)
|
||||
for i in range(len(input_ids)):
|
||||
input_ids_no_padding[
|
||||
0, lengths_acc[i]:lengths_acc[i + 1]] = torch.IntTensor(
|
||||
input_ids[i,
|
||||
max_input_len - input_lengths[i]:max_input_len])
|
||||
|
||||
input_ids = input_ids_no_padding
|
||||
|
||||
elif use_gpt_attention_plugin:
|
||||
# when using gpt attention plugin, inputs needs to align at the head
|
||||
input_ids_padding_right = torch.zeros_like(input_ids) + end_id
|
||||
for i, sample in enumerate(input_ids):
|
||||
nPadding = 0
|
||||
for token in sample:
|
||||
if token == pad_id:
|
||||
nPadding += 1
|
||||
else:
|
||||
break
|
||||
input_ids_padding_right[
|
||||
i, :len(sample[nPadding:])] = sample[nPadding:]
|
||||
input_ids = input_ids_padding_right
|
||||
|
||||
model_config = ModelConfig(
|
||||
vocab_size=config['builder_config']['vocab_size'],
|
||||
num_layers=config['builder_config']['num_layers'],
|
||||
num_heads=config['builder_config']['num_heads'] // tp_size,
|
||||
num_kv_heads=(config['builder_config']['num_kv_heads'] + tp_size - 1) //
|
||||
tp_size,
|
||||
hidden_size=config['builder_config']['hidden_size'] // tp_size,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
remove_input_padding=config['builder_config']['remove_input_padding'],
|
||||
model_name=args.model_name,
|
||||
paged_kv_cache=config['builder_config']['paged_kv_cache'],
|
||||
quant_mode=QuantMode(config['builder_config']['quant_mode']),
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
sampling_config = SamplingConfig(
|
||||
end_id=eop_id if args.model_name in ["glm_10b"] else end_id,
|
||||
pad_id=pad_id,
|
||||
num_beams=beam_width,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
top_p=args.top_p,
|
||||
)
|
||||
sampling_config.random_seed = args.random_seed
|
||||
'''
|
||||
with open(serialize_path, 'rb') as f:
|
||||
engine_buffer = f.read()
|
||||
'''
|
||||
engine_buffer = serialize_path
|
||||
|
||||
if args.model_name in ["chatglm_6b", "glm_10b"]:
|
||||
session = ChatGLMGenerationSession
|
||||
elif args.model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
session = GenerationSession
|
||||
decoder = session(
|
||||
model_config,
|
||||
engine_buffer,
|
||||
runtime_mapping,
|
||||
)
|
||||
|
||||
decoder.setup(
|
||||
len(input_text),
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
beam_width,
|
||||
)
|
||||
output = decoder.decode(
|
||||
input_ids.contiguous().cuda(),
|
||||
input_lengths.contiguous().cuda(),
|
||||
sampling_config,
|
||||
output_sequence_lengths=True,
|
||||
return_dict=True,
|
||||
streaming=args.streaming,
|
||||
stop_words_list=None if args.model_name in ["chatglm_6b", "glm_10b"]
|
||||
else [tokenizer.eos_token_id],
|
||||
)
|
||||
if args.performance_test_scale != "":
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
for scale in args.performance_test_scale.split("E"):
|
||||
xtrt_llm.logger.info(f"Running performance test with scale {scale}")
|
||||
bs, seqlen, _max_output_len = [int(x) for x in scale.split("x")]
|
||||
_input_ids = torch.from_numpy(
|
||||
np.zeros((bs, seqlen)).astype("int32")).cuda()
|
||||
_input_lengths = torch.from_numpy(
|
||||
np.full((bs, ), seqlen).astype("int32")).cuda()
|
||||
_max_input_length = torch.max(_input_lengths).item()
|
||||
if model_config.remove_input_padding:
|
||||
_input_ids = _input_ids.view((1, -1)).contiguous()
|
||||
|
||||
_t_begin = time.time()
|
||||
decoder.setup(_input_lengths.size(0), _max_input_length,
|
||||
_max_output_len, beam_width)
|
||||
_output_gen_ids = decoder.decode(_input_ids,
|
||||
_input_lengths,
|
||||
sampling_config,
|
||||
streaming=streaming)
|
||||
_t_end = time.time()
|
||||
xtrt_llm.logger.info(
|
||||
f"Total latency: {(_t_end - _t_begin) * 1000:.3f} ms")
|
||||
xtrt_llm.logger.info(
|
||||
f"Throughput: {bs * _max_output_len / (_t_end - _t_begin):.3f} tokens/sec"
|
||||
)
|
||||
exit(0)
|
||||
|
||||
if runtime_rank == 0:
|
||||
if args.model_name in ["chatglm_6b"]:
|
||||
from process import process_response_chatglm_6b as process_response
|
||||
elif args.model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
"glm_10b",
|
||||
]:
|
||||
from process import process_response
|
||||
|
||||
if args.streaming: # streaming output
|
||||
print("#" * 80)
|
||||
# only the first sample in the first batch is shown,
|
||||
# but actually all output of all batches are available
|
||||
print(f"Input idx: {0:2d} ---> len={input_lengths[0]}")
|
||||
print(f'Input: \"{input_text[0]}\""')
|
||||
for output_item in output:
|
||||
output_id = output_item["output_ids"]
|
||||
output_sequence_lengths = output_item["sequence_lengths"]
|
||||
output_id = output_id[0, 0, output_sequence_lengths[0, 0] - 1]
|
||||
output_word = tokenizer.convert_ids_to_tokens(int(output_id))
|
||||
output_word = output_word.replace("▁", " ") # For English
|
||||
output_word = tokenizer.convert_tokens_to_string(output_word)
|
||||
print(output_word, end="", flush=True)
|
||||
print("\n" + "#" * 80)
|
||||
else: # regular output
|
||||
torch.cuda.synchronize()
|
||||
output_ids = output["output_ids"]
|
||||
output_lengths = output["sequence_lengths"]
|
||||
print("#" * 80)
|
||||
for i in range(batch_size):
|
||||
print(f'Input idx: {i:2d} ---> len={input_lengths[i]}')
|
||||
print(f'Input: \"{input_text[i]}\"')
|
||||
print(f"Output idx: {i:2d} --->")
|
||||
output_ids_one_batch = output_ids[i, :, input_lengths[i]:]
|
||||
output_lengths_one_batch = output_lengths[i] - input_lengths[
|
||||
i] + 1
|
||||
output_token_list = tokenizer.batch_decode(
|
||||
output_ids_one_batch, skip_special_tokens=True)
|
||||
output_token_list = process_response(output_token_list)
|
||||
for j, (length, simple_output) in enumerate(
|
||||
zip(output_lengths_one_batch, output_token_list)):
|
||||
print("Beam %2d ---> len=%d" %(j, length))
|
||||
print(f'Output: \"{simple_output}\"')
|
||||
print("#" * 80)
|
||||
|
||||
del decoder
|
||||
|
||||
print(f"Finished from worker {runtime_rank}")
|
||||
128
examples/chatglm/run.sh
Normal file
128
examples/chatglm/run.sh
Normal file
@@ -0,0 +1,128 @@
|
||||
XMLIR_D_XPU_L3_SIZE=0 python3 run.py -m chatglm2_6b --engine_dir engine_outputs --tokenizer_dir downloads/chatglm2-6b --input_text="中华人民共和国主席令
|
||||
(第八十三号)
|
||||
《中华人民共和国刑法》已由中华人民共和国第八届全国人民代表大会第五次会议于1997年3月14日修订,现将修订后的《中华人民共和国刑法》公布,自1997年10月1日起施行。
|
||||
1997年3月14日
|
||||
中华人民共和国刑法
|
||||
(1979年7月1日第五届全国人民代表大会第二次会议通过,
|
||||
1997年3月14日第八届全国人民代表大会第五次会议修订)
|
||||
第一编 总 则
|
||||
第一章 刑法的任务、基本原则和适用范围
|
||||
第一条 为了惩罚犯罪,保护人民,根据宪法,结合我国同犯罪作斗争的具体经验及实际情况,制定本法。
|
||||
第二条 中华人民共和国刑法的任务,是用刑罚同一切犯罪行为作斗争,以保卫国家安全,保卫人民民主专政的政权和社会主义制度,保护国有财产和劳动群众集体所有的财产,保护公民私人所有的财产,保护公民的人身权利、民主权利和其他权利,维护社会秩序、经济秩序,保障社会主义建设事业的顺利进行。
|
||||
第三条 法律明文规定为犯罪行为的,依照法律定罪处刑;法律没有明文规定为犯罪行为的,不得定罪处刑。
|
||||
第四条 对任何人犯罪,在适用法律上一律平等。不允许任何人有超越法律的特权。
|
||||
第五条 刑罚的轻重,应当与犯罪分子所犯罪行和承担的刑事责任相适应。
|
||||
第六条 凡在中华人民共和国领域内犯罪的,除法律有特别规定的以外,都适用本法。
|
||||
凡在中华人民共和国船舶或者航空器内犯罪的,也适用本法。
|
||||
犯罪的行为或者结果有一项发生在中华人民共和国领域内的,就认为是在中华人民共和国领域内犯罪。
|
||||
第七条 中华人民共和国公民在中华人民共和国领域外犯本法规定之罪的,适用本法,但是按本法规定的最高刑为三年以下有期徒刑的,可以不予追究。
|
||||
中华人民共和国国家工作人员和军人在中华人民共和国领域外犯本法规定之罪的,适用本法。
|
||||
第八条 外国人在中华人民共和国领域外对中华人民共和国国家或者公民犯罪,而按本法规定的最低刑为三年以上有期徒刑的,可以适用本法,但是按照犯罪地的法律不受处罚的除外。
|
||||
第九条 对于中华人民共和国缔结或者参加的国际条约所规定的罪行,中华人民共和国在所承担条约义务的范围内行使刑事管辖权的,适用本法。
|
||||
第十条 凡在中华人民共和国领域外犯罪,依照本法应当负刑事责任的,虽然经过外国审判,仍然可以依照本法追究,但是在外国已经受过刑罚处罚的,可以免除或者减轻处罚。
|
||||
第十一条 享有外交特权和豁免权的外国人的刑事责任,通过外交途径解决。
|
||||
第十二条 中华人民共和国成立以后本法施行以前的行为,如果当时的法律不认为是犯罪的,适用当时的法律;如果当时的法律认为是犯罪的,依照本法总则第四章第八节的规定应当追诉的,按照当时的法律追究刑事责任,但是如果本法不认为是犯罪或者处刑较轻的,适用本法。
|
||||
本法施行以前,依照当时的法律已经作出的生效判决,继续有效。
|
||||
第二章 犯罪
|
||||
第一节 犯罪和刑事责任
|
||||
第十三条 一切危害国家主权、领土完整和安全,分裂国家、颠覆人民民主专政的政权和推翻社会主义制度,破坏社会秩序和经济秩序,侵犯国有财产或者劳动群众集体所有的财产,侵犯公民私人所有的财产,侵犯公民的人身权利、民主权利和其他权利,以及其他危害社会的行为,依照法律应当受刑罚处罚的,都是犯罪,但是情节显著轻微危害不大的,不认为是犯罪。
|
||||
第十四条 明知自己的行为会发生危害社会的结果,并且希望或者放任这种结果发生,因而构成犯罪的,是故意犯罪。
|
||||
故意犯罪,应当负刑事责任。
|
||||
第十五条 应当预见自己的行为可能发生危害社会的结果,因为疏忽大意而没有预见,或者已经预见而轻信能够避免,以致发生这种结果的,是过失犯罪。
|
||||
过失犯罪,法律有规定的才负刑事责任。
|
||||
第十六条 行为在客观上虽然造成了损害结果,但是不是出于故意或者过失,而是由于不能抗拒或者不能预见的原因所引起的,不是犯罪。
|
||||
第十七条 已满十六周岁的人犯罪,应当负刑事责任。
|
||||
已满十四周岁不满十六周岁的人,犯故意杀人、故意伤害致人重伤或者死亡、强奸、抢劫、贩卖毒品、放火、爆炸、投毒罪的,应当负刑事责任。
|
||||
已满十四周岁不满十八周岁的人犯罪,应当从轻或者减轻处罚。
|
||||
因不满十六周岁不予刑事处罚的,责令他的家长或者监护人加以管教;在必要的时候,也可以由政府收容教养。
|
||||
第十八条 精神病人在不能辨认或者不能控制自己行为的时候造成危害结果,经法定程序鉴定确认的,不负刑事责任,但是应当责令他的家属或者监护人严加看管和医疗;在必要的时候,由政府强制医疗。
|
||||
间歇性的精神病人在精神正常的时候犯罪,应当负刑事责任。
|
||||
尚未完全丧失辨认或者控制自己行为能力的精神病人犯罪的,应当负刑事责任,但是可以从轻或者减轻处罚。
|
||||
醉酒的人犯罪,应当负刑事责任。
|
||||
第十九条 又聋又哑的人或者盲人犯罪,可以从轻、减轻或者免除处罚。
|
||||
第二十条 为了使国家、公共利益、本人或者他人的人身、财产和其他权利免受正在进行的不法侵害,而采取的制止不法侵害的行为,对不法侵害人造成损害的,属于正当防卫,不负刑事责任。
|
||||
正当防卫明显超过必要限度造成重大损害的,应当负刑事责任,但是应当减轻或者免除处罚。
|
||||
对正在进行行凶、杀人、抢劫、强奸、绑架以及其他严重危及人身安全的暴力犯罪,采取防卫行为,造成不法侵害人伤亡的,不属于防卫过当,不负刑事责任。
|
||||
第二十一条 为了使国家、公共利益、本人或者他人的人身、财产和其他权利免受正在发生的危险,不得已采取的紧急避险行为,造成损害的,不负刑事责任。
|
||||
紧急避险超过必要限度造成不应有的损害的,应当负刑事责任,但是应当减轻或者免除处罚。
|
||||
第一款中关于避免本人危险的规定,不适用于职务上、业务上负有特定责任的人。
|
||||
第二节 犯罪的预备、未遂和中止
|
||||
第二十二条 为了犯罪,准备工具、制造条件的,是犯罪预备。
|
||||
对于预备犯,可以比照既遂犯从轻、减轻处罚或者免除处罚。
|
||||
第二十三条 已经着手实行犯罪,由于犯罪分子意志以外的原因而未得逞的,是犯罪未遂。
|
||||
对于未遂犯,可以比照既遂犯从轻或者减轻处罚。
|
||||
第二十四条 在犯罪过程中,自动放弃犯罪或者自动有效地防止犯罪结果发生的,是犯罪中止。
|
||||
对于中止犯,没有造成损害的,应当免除处罚;造成损害的,应当减轻处罚。
|
||||
第三节 共同犯罪
|
||||
第二十五条 共同犯罪是指二人以上共同故意犯罪。
|
||||
二人以上共同过失犯罪,不以共同犯罪论处;应当负刑事责任的,按照他们所犯的罪分别处罚。
|
||||
第二十六条 组织、领导犯罪集团进行犯罪活动的或者在共同犯罪中起主要作用的,是主犯。
|
||||
三人以上为共同实施犯罪而组成的较为固定的犯罪组织,是犯罪集团。
|
||||
对组织、领导犯罪集团的首要分子,按照集团所犯的全部罪行处罚。
|
||||
对于第三款规定以外的主犯,应当按照其所参与的或者组织、指挥的全部犯罪处罚。
|
||||
第二十七条 在共同犯罪中起次要或者辅助作用的,是从犯。
|
||||
对于从犯,应当从轻、减轻处罚或者免除处罚。
|
||||
第二十八条 对于被胁迫参加犯罪的,应当按照他的犯罪情节减轻处罚或者免除处罚。
|
||||
第二十九条 教唆他人犯罪的,应当按照他在共同犯罪中所起的作用处罚。教唆不满十八周岁的人犯罪的,应当从重处罚。
|
||||
如果被教唆的人没有犯被教唆的罪,对于教唆犯,可以从轻或者减轻处罚。
|
||||
第四节 单位犯罪
|
||||
第三十条 公司、企业、事业单位、机关、团体实施的危害社会的行为,法律规定为单位犯罪的,应当负刑事责任。
|
||||
第三十一条 单位犯罪的,对单位判处罚金,并对其直接负责的主管人员和其他直接责任人员判处刑罚。本法分则和其他法律另有规定的,依照规定。
|
||||
第三章 刑罚
|
||||
第一节 刑罚的种类
|
||||
第三十二条 刑罚分为主刑和附加刑。
|
||||
第三十三条 主刑的种类如下:
|
||||
(一)管制;
|
||||
(二)拘役;
|
||||
(三)有期徒刑;
|
||||
(四)无期徒刑;
|
||||
(五)死刑。
|
||||
第三十四条 附加刑的种类如下:
|
||||
(一)罚金;
|
||||
(二)剥夺政治权利;
|
||||
(三)没收财产。
|
||||
附加刑也可以独立适用。
|
||||
第三十五条 对于犯罪的外国人,可以独立适用或者附加适用驱逐出境。
|
||||
第三十六条 由于犯罪行为而使被害人遭受经济损失的,对犯罪分子除依法给予刑事处罚外,并应根据情况判处赔偿经济损失。
|
||||
承担民事赔偿责任的犯罪分子,同时被判处罚金,其财产不足以全部支付的,或者被判处没收财产的,应当先承担对被害人的民事赔偿责任。
|
||||
第三十七条 对于犯罪情节轻微不需要判处刑罚的,可以免予刑事处罚,但是可以根据案件的不同情况,予以训诫或者责令具结悔过、赔礼道歉、赔偿损失,或者由主管部门予以行政处罚或者行政处分。
|
||||
第二节 管制
|
||||
第三十八条 管制的期限,为三个月以上二年以下。
|
||||
被判处管制的犯罪分子,由公安机关执行。
|
||||
第三十九条 被判处管制的犯罪分子,在执行期间,应当遵守下列规定:
|
||||
(一)遵守法律、行政法规,服从监督;
|
||||
(二)未经执行机关批准,不得行使言论、出版、集会、结社、游行、示威自由的权利;
|
||||
(三)按照执行机关规定报告自己的活动情况;
|
||||
(四)遵守执行机关关于会客的规定;
|
||||
(五)离开所居住的市、县或者迁居,应当报经执行机关批准。
|
||||
对于被判处管制的犯罪分子,在劳动中应当同工同酬。
|
||||
第四十条 被判处管制的犯罪分子,管制期满,执行机关应即向本人和其所在单位或者居住地的群众宣布解除管制。
|
||||
第四十一条 管制的刑期,从判决执行之日起计算;判决执行以前先行羁押的,羁押一日折抵刑期二日。
|
||||
第三节 拘役
|
||||
第四十二条 拘役的期限,为一个月以上六个月以下。
|
||||
第四十三条 被判处拘役的犯罪分子,由公安机关就近执行。
|
||||
在执行期间,被判处拘役的犯罪分子每月可以回家一天至两天;参加劳动的,可以酌量发给报酬。
|
||||
第四十四条 拘役的刑期,从判决执行之日起计算;判决执行以前先行羁押的,羁押一日折抵刑期一日。
|
||||
第四节 有期徒刑、无期徒刑
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
第四十五条 有期徒刑的期限,除本法第五十条、第六十九条规定外,为六个月以上十五年以下。
|
||||
|
||||
问:杀人、抢劫、强奸的犯什么罪?
|
||||
答:"
|
||||
155
examples/chatglm/smoothquant.py
Normal file
155
examples/chatglm/smoothquant.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
'''
|
||||
Utilities for SmoothQuant models
|
||||
'''
|
||||
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_smoothing(scales,
|
||||
gemm_weights,
|
||||
layernorm_weights=None,
|
||||
layernorm_bias=None,
|
||||
dtype=torch.float32,
|
||||
layernorm_1p=False):
|
||||
if not isinstance(gemm_weights, list):
|
||||
gemm_weights = [gemm_weights]
|
||||
|
||||
if layernorm_weights is not None:
|
||||
assert layernorm_weights.numel() == scales.numel()
|
||||
layernorm_weights.div_(scales).to(dtype)
|
||||
if layernorm_bias is not None:
|
||||
assert layernorm_bias.numel() == scales.numel()
|
||||
layernorm_bias.div_(scales).to(dtype)
|
||||
if layernorm_1p:
|
||||
layernorm_weights += (1 / scales) - 1
|
||||
|
||||
for gemm in gemm_weights:
|
||||
gemm.mul_(scales.view(1, -1)).to(dtype)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_gemm(gemm_weights,
|
||||
act_scales,
|
||||
layernorm_weights=None,
|
||||
layernorm_bias=None,
|
||||
alpha=0.5,
|
||||
weight_scales=None):
|
||||
if not isinstance(gemm_weights, list):
|
||||
gemm_weights = [gemm_weights]
|
||||
orig_dtype = gemm_weights[0].dtype
|
||||
|
||||
for gemm in gemm_weights:
|
||||
# gemm_weights are expected to be transposed
|
||||
assert gemm.shape[1] == act_scales.numel()
|
||||
|
||||
if weight_scales is None:
|
||||
weight_scales = torch.cat(
|
||||
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
|
||||
dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0]
|
||||
weight_scales.to(float).clamp(min=1e-5)
|
||||
scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5)
|
||||
|
||||
apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias,
|
||||
orig_dtype)
|
||||
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
|
||||
if not isinstance(fcs, list):
|
||||
fcs = [fcs]
|
||||
for fc in fcs:
|
||||
assert isinstance(fc, nn.Linear)
|
||||
assert ln.weight.numel() == fc.in_features == act_scales.numel()
|
||||
|
||||
device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
|
||||
act_scales = act_scales.to(device=device, dtype=dtype)
|
||||
weight_scales = torch.cat(
|
||||
[fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
|
||||
weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
|
||||
|
||||
scales = (act_scales.pow(alpha) /
|
||||
weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
|
||||
|
||||
if ln is not None:
|
||||
ln.weight.div_(scales)
|
||||
ln.bias.div_(scales)
|
||||
|
||||
for fc in fcs:
|
||||
fc.weight.mul_(scales.view(1, -1))
|
||||
return scales
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def capture_activation_range(model,
|
||||
tokenizer,
|
||||
dataset,
|
||||
num_samples=512,
|
||||
seq_len=512):
|
||||
model.eval()
|
||||
device = next(model.parameters()).device
|
||||
act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
|
||||
|
||||
def stat_tensor(name, tensor, act_scales, key):
|
||||
hidden_dim = tensor.shape[-1]
|
||||
tensor = tensor.view(-1, hidden_dim).abs().detach()
|
||||
comming_max = torch.max(tensor, dim=0)[0].float()
|
||||
|
||||
if act_scales[name][key] is None:
|
||||
act_scales[name][key] = comming_max
|
||||
else:
|
||||
act_scales[name][key] = torch.max(act_scales[name][key],
|
||||
comming_max)
|
||||
|
||||
def stat_input_hook(m, x, y, name):
|
||||
if isinstance(x, tuple):
|
||||
x = x[0]
|
||||
stat_tensor(name, x, act_scales, "x")
|
||||
stat_tensor(name, y, act_scales, "y")
|
||||
|
||||
if act_scales[name]["w"] is None:
|
||||
act_scales[name]["w"] = m.weight.abs().clip(1e-8,
|
||||
None).max(dim=0)[0]
|
||||
|
||||
hooks = []
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
|
||||
hooks.append(
|
||||
m.register_forward_hook(
|
||||
functools.partial(stat_input_hook, name=name)))
|
||||
|
||||
for i in tqdm(range(num_samples), desc="calibrating model"):
|
||||
input_ids = tokenizer(dataset[i]["text"],
|
||||
return_tensors="pt",
|
||||
max_length=seq_len,
|
||||
truncation=True).input_ids.to(device)
|
||||
model(input_ids)
|
||||
|
||||
for h in hooks:
|
||||
h.remove()
|
||||
|
||||
return act_scales
|
||||
73
examples/chatglm/visualize.py
Normal file
73
examples/chatglm/visualize.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import onnx
|
||||
import tvm.tensorrt as trt
|
||||
from onnx import TensorProto, helper
|
||||
|
||||
|
||||
def trt_dtype_to_onnx(dtype):
|
||||
if dtype == trt.float16:
|
||||
return TensorProto.DataType.FLOAT16
|
||||
elif dtype == trt.float32:
|
||||
return TensorProto.DataType.FLOAT
|
||||
elif dtype == trt.int32:
|
||||
return TensorProto.DataType.INT32
|
||||
else:
|
||||
raise TypeError("%s is not supported" % dtype)
|
||||
|
||||
|
||||
def to_onnx(network, path):
|
||||
inputs = []
|
||||
for i in range(network.num_inputs):
|
||||
network_input = network.get_input(i)
|
||||
inputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
network_input.name, trt_dtype_to_onnx(network_input.dtype),
|
||||
list(network_input.shape)))
|
||||
|
||||
outputs = []
|
||||
for i in range(network.num_outputs):
|
||||
network_output = network.get_output(i)
|
||||
outputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
network_output.name, trt_dtype_to_onnx(network_output.dtype),
|
||||
list(network_output.shape)))
|
||||
|
||||
nodes = []
|
||||
for i in range(network.num_layers):
|
||||
layer = network.get_layer(i)
|
||||
layer_inputs = []
|
||||
for j in range(layer.num_inputs):
|
||||
ipt = layer.get_input(j)
|
||||
if ipt is not None:
|
||||
layer_inputs.append(layer.get_input(j).name)
|
||||
layer_outputs = [
|
||||
layer.get_output(j).name for j in range(layer.num_outputs)
|
||||
]
|
||||
nodes.append(
|
||||
helper.make_node(str(layer.type),
|
||||
name=layer.name,
|
||||
inputs=layer_inputs,
|
||||
outputs=layer_outputs,
|
||||
domain="com.nvidia"))
|
||||
|
||||
onnx_model = helper.make_model(helper.make_graph(nodes,
|
||||
'attention',
|
||||
inputs,
|
||||
outputs,
|
||||
initializer=None),
|
||||
producer_name='NVIDIA')
|
||||
onnx.save(onnx_model, path)
|
||||
590
examples/chatglm/weight.py
Normal file
590
examples/chatglm/weight.py
Normal file
@@ -0,0 +1,590 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
|
||||
import xtrt_llm as tensorrt_llm
|
||||
import xtrt_llm.logger as logger
|
||||
from xtrt_llm._utils import str_dtype_to_torch, torch_to_numpy
|
||||
from xtrt_llm.mapping import Mapping
|
||||
from xtrt_llm.models.quantized.quant import get_dummy_quant_scales
|
||||
from xtrt_llm.quantization import QuantMode
|
||||
|
||||
|
||||
def split(weight: np.ndarray, tp_size: int, rank: int = 0, dim: int = 0):
|
||||
if tp_size == 1:
|
||||
return weight
|
||||
elif weight.ndim == 1:
|
||||
return np.ascontiguousarray(np.split(weight, tp_size)[rank].copy())
|
||||
return np.ascontiguousarray(
|
||||
np.split(weight, tp_size, axis=dim)[rank].copy())
|
||||
|
||||
|
||||
def split_matrix(weight: np.ndarray, tp_size: int, rank: int, dim: int):
|
||||
return np.ascontiguousarray(split(weight, tp_size, rank, dim=dim))
|
||||
|
||||
|
||||
def tile_kv_weight_bias(v, kv_num_head, tp_size):
|
||||
head_size = v.shape[0] // kv_num_head
|
||||
reps = tp_size // kv_num_head
|
||||
if v.ndim == 1:
|
||||
v = v.reshape(kv_num_head, head_size)[:, None, :]
|
||||
v = v.expand(kv_num_head, reps, head_size).reshape(-1).clone()
|
||||
else:
|
||||
hidden_size = v.shape[1]
|
||||
v = v.reshape(kv_num_head, head_size, hidden_size)[:, None, :, :]
|
||||
v = v.expand(kv_num_head, reps, head_size,
|
||||
hidden_size).reshape(-1, hidden_size).clone()
|
||||
return v
|
||||
|
||||
|
||||
def split_qkv(v, tp_size, rank, hidden_size, num_heads, num_kv_heads):
|
||||
head_size = hidden_size // num_heads
|
||||
if tp_size == 1:
|
||||
return v
|
||||
|
||||
assert v.shape[0] == hidden_size + head_size * num_kv_heads * 2
|
||||
query = v[:hidden_size]
|
||||
key = v[hidden_size:hidden_size + head_size * num_kv_heads]
|
||||
value = v[hidden_size + head_size * num_kv_heads:hidden_size +
|
||||
head_size * num_kv_heads * 2]
|
||||
|
||||
if num_kv_heads < tp_size:
|
||||
key = tile_kv_weight_bias(key, num_kv_heads, tp_size)
|
||||
value = tile_kv_weight_bias(value, num_kv_heads, tp_size)
|
||||
assert (key.shape[0] % (tp_size * head_size)) == 0
|
||||
assert (value.shape[0] % (tp_size * head_size)) == 0
|
||||
|
||||
q_tmp = torch.chunk(query, tp_size, dim=0)[rank]
|
||||
k_tmp = torch.chunk(key, tp_size, dim=0)[rank]
|
||||
v_tmp = torch.chunk(value, tp_size, dim=0)[rank]
|
||||
return torch.concatenate([q_tmp, k_tmp, v_tmp], dim=0).contiguous()
|
||||
|
||||
|
||||
def load_quant_weight(src, value_dst, scale_dst, plugin_weight_only_quant_type):
|
||||
v = torch.transpose(src, dim0=0, dim1=1).contiguous()
|
||||
processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
|
||||
v, plugin_weight_only_quant_type)
|
||||
value_dst.value = torch_to_numpy(processed_torch_weights)
|
||||
scale_dst.value = torch_to_numpy(torch_weight_scales)
|
||||
|
||||
|
||||
def load_from_hf(
|
||||
trt_model,
|
||||
hf_model_dir,
|
||||
mapping=Mapping(),
|
||||
dtype="float32",
|
||||
model_name=None,
|
||||
multi_query_mode=False,
|
||||
):
|
||||
|
||||
assert model_name is not None, "Model name must be set"
|
||||
|
||||
tensorrt_llm.logger.info("Loading weights from HF")
|
||||
|
||||
if not Path(hf_model_dir).exists():
|
||||
tensorrt_llm.logger.info(
|
||||
"No weight file found from %s, use random weights" % hf_model_dir)
|
||||
return trt_model
|
||||
|
||||
tik = time.time()
|
||||
|
||||
hf_model = transformers.AutoModel.from_pretrained(hf_model_dir,
|
||||
trust_remote_code=True)
|
||||
hidden_size = hf_model.config.hidden_size
|
||||
num_heads = hf_model.config.num_attention_heads
|
||||
num_layers = hf_model.config.num_layers
|
||||
|
||||
torch_type = str_dtype_to_torch(dtype)
|
||||
quant_mode = getattr(trt_model, 'quant_mode', QuantMode(0))
|
||||
if quant_mode.is_int8_weight_only():
|
||||
plugin_weight_only_quant_type = torch.int8
|
||||
elif quant_mode.is_int4_weight_only():
|
||||
plugin_weight_only_quant_type = torch.quint4x2
|
||||
use_weight_only = quant_mode.is_weight_only()
|
||||
|
||||
layers_per_pipeline_stage = num_layers // mapping.pp_size
|
||||
layers_range = list(
|
||||
range(mapping.pp_rank * layers_per_pipeline_stage,
|
||||
(mapping.pp_rank + 1) * layers_per_pipeline_stage))
|
||||
feed_weight_count = 0
|
||||
|
||||
if model_name in ["chatglm_6b", "glm_10b"]:
|
||||
num_kv_heads = hf_model.config.num_attention_heads
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
num_kv_heads = hf_model.config.multi_query_group_num
|
||||
|
||||
if mapping.is_first_pp_rank():
|
||||
# Embedding
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.transformer.word_embeddings.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.embedding.weight.value = torch_to_numpy(weight)
|
||||
feed_weight_count += 1
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.embedding.word_embeddings.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.embedding.weight.value = torch_to_numpy(weight)
|
||||
feed_weight_count += 1
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.word_embeddings.weight.to(torch_type).detach()
|
||||
trt_model.embedding.weight.value = torch_to_numpy(weight)
|
||||
weight = hf_model.transformer.position_embeddings.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.position_embeddings.weight.value = torch_to_numpy(weight)
|
||||
weight = hf_model.transformer.block_position_embeddings.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.block_embeddings.weight.value = torch_to_numpy(weight)
|
||||
feed_weight_count += 3
|
||||
|
||||
if mapping.is_last_pp_rank():
|
||||
# Final normalization
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.transformer.final_layernorm.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.final_norm.weight.value = torch_to_numpy(weight)
|
||||
bias = hf_model.transformer.final_layernorm.bias.to(
|
||||
torch_type).detach()
|
||||
trt_model.final_norm.bias.value = torch_to_numpy(bias)
|
||||
feed_weight_count += 2
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.encoder.final_layernorm.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.final_norm.weight.value = torch_to_numpy(weight)
|
||||
feed_weight_count += 1
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.transformer.final_layernorm.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.final_norm.weight.value = torch_to_numpy(weight)
|
||||
bias = hf_model.transformer.final_layernorm.bias.to(
|
||||
torch_type).detach()
|
||||
trt_model.final_norm.bias.value = torch_to_numpy(bias)
|
||||
feed_weight_count += 2
|
||||
|
||||
# Final LM
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.lm_head.weight.to(torch_type).detach()
|
||||
if weight.shape[0] % mapping.tp_size != 0:
|
||||
pad_width = trt_model.lm_head.out_features * mapping.tp_size - weight.shape[
|
||||
0]
|
||||
weight = F.pad(weight, (0, 0, 0, pad_width))
|
||||
split_weight = torch.chunk(weight, mapping.tp_size,
|
||||
dim=0)[mapping.rank]
|
||||
trt_model.lm_head.weight.value = torch_to_numpy(split_weight)
|
||||
feed_weight_count += 1
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.output_layer.weight.to(
|
||||
torch_type).detach()
|
||||
if weight.shape[0] % mapping.tp_size != 0:
|
||||
pad_width = trt_model.lm_head.out_features * mapping.tp_size - weight.shape[
|
||||
0]
|
||||
weight = F.pad(weight, (0, 0, 0, pad_width))
|
||||
split_weight = torch.chunk(weight, mapping.tp_size,
|
||||
dim=0)[mapping.rank]
|
||||
trt_model.lm_head.weight.value = torch_to_numpy(split_weight)
|
||||
feed_weight_count += 1
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.word_embeddings.weight.to(torch_type).detach()
|
||||
if weight.shape[0] % mapping.tp_size != 0:
|
||||
pad_width = trt_model.lm_head.out_features * mapping.tp_size - weight.shape[
|
||||
0]
|
||||
weight = F.pad(weight, (0, 0, 0, pad_width))
|
||||
split_weight = torch.chunk(weight, mapping.tp_size,
|
||||
dim=0)[mapping.rank]
|
||||
trt_model.lm_head.weight.value = torch_to_numpy(split_weight)
|
||||
feed_weight_count += 1
|
||||
|
||||
# Weight per layer
|
||||
for layer_idx in range(num_layers):
|
||||
if layer_idx not in layers_range:
|
||||
continue
|
||||
i = int(layer_idx) - mapping.pp_rank * layers_per_pipeline_stage
|
||||
if i >= num_layers:
|
||||
continue
|
||||
|
||||
# Pre normalization
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.transformer.layers[i].input_layernorm.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.layers[i].pre_norm.weight.value = torch_to_numpy(weight)
|
||||
bias = hf_model.transformer.layers[i].input_layernorm.bias.to(
|
||||
torch_type).detach()
|
||||
trt_model.layers[i].pre_norm.bias.value = torch_to_numpy(bias)
|
||||
feed_weight_count += 2
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.encoder.layers[
|
||||
i].input_layernorm.weight.to(torch_type).detach()
|
||||
trt_model.layers[i].pre_norm.weight.value = torch_to_numpy(weight)
|
||||
feed_weight_count += 1
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.transformer.layers[i].input_layernorm.weight.to(
|
||||
torch_type).detach()
|
||||
trt_model.layers[i].pre_norm.weight.value = torch_to_numpy(weight)
|
||||
bias = hf_model.transformer.layers[i].input_layernorm.bias.to(
|
||||
torch_type).detach()
|
||||
trt_model.layers[i].pre_norm.bias.value = torch_to_numpy(bias)
|
||||
feed_weight_count += 2
|
||||
|
||||
# QKV multiplication weight
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.transformer.layers[
|
||||
i].attention.query_key_value.weight.to(torch_type).detach()
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.encoder.layers[
|
||||
i].self_attention.query_key_value.weight.to(
|
||||
torch_type).detach()
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.transformer.layers[
|
||||
i].attention.query_key_value.weight.to(torch_type).detach()
|
||||
|
||||
split_weight = split_qkv(weight, mapping.tp_size, mapping.tp_rank,
|
||||
hidden_size, num_heads, num_kv_heads)
|
||||
dst = trt_model.layers[i].attention.qkv
|
||||
if use_weight_only:
|
||||
load_quant_weight(
|
||||
src=split_weight,
|
||||
value_dst=dst.weight,
|
||||
scale_dst=dst.per_channel_scale,
|
||||
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
|
||||
else:
|
||||
dst.weight.value = torch_to_numpy(split_weight)
|
||||
feed_weight_count += 1
|
||||
|
||||
# QKV multiplication bias
|
||||
if model_name in ["chatglm_6b"]:
|
||||
bias = hf_model.transformer.layers[
|
||||
i].attention.query_key_value.bias.to(torch_type).detach()
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
bias = hf_model.transformer.encoder.layers[
|
||||
i].self_attention.query_key_value.bias.to(torch_type).detach()
|
||||
elif model_name in ["glm_10b"]:
|
||||
bias = hf_model.transformer.layers[
|
||||
i].attention.query_key_value.bias.to(torch_type).detach()
|
||||
|
||||
split_bias = split_qkv(bias, mapping.tp_size, mapping.tp_rank,
|
||||
hidden_size, num_heads, num_kv_heads)
|
||||
trt_model.layers[i].attention.qkv.bias.value = torch_to_numpy(
|
||||
split_bias)
|
||||
feed_weight_count += 1
|
||||
|
||||
# Dense multiplication weight
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.transformer.layers[i].attention.dense.weight.to(
|
||||
torch_type).detach()
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.encoder.layers[
|
||||
i].self_attention.dense.weight.to(torch_type).detach()
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.transformer.layers[i].attention.dense.weight.to(
|
||||
torch_type).detach()
|
||||
|
||||
split_weight = torch.chunk(weight, mapping.tp_size, dim=1)[mapping.rank]
|
||||
dst = trt_model.layers[i].attention.dense
|
||||
if use_weight_only:
|
||||
load_quant_weight(
|
||||
src=split_weight,
|
||||
value_dst=dst.weight,
|
||||
scale_dst=dst.per_channel_scale,
|
||||
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
|
||||
else:
|
||||
dst.weight.value = np.ascontiguousarray(
|
||||
torch_to_numpy(split_weight))
|
||||
feed_weight_count += 1
|
||||
|
||||
# Dense multiplication bias, only GLM-10B
|
||||
if model_name in ["glm_10b", "chatglm_6b"]:
|
||||
bias = hf_model.transformer.layers[i].attention.dense.bias.to(
|
||||
torch_type).detach()
|
||||
split_bias = split_qkv(bias, mapping.tp_size, mapping.tp_rank,
|
||||
hidden_size, num_heads, num_kv_heads)
|
||||
trt_model.layers[i].attention.dense.bias.value = torch_to_numpy(
|
||||
split_bias)
|
||||
feed_weight_count += 1
|
||||
|
||||
# Post normalization
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.transformer.layers[
|
||||
i].post_attention_layernorm.weight.to(torch_type).detach()
|
||||
trt_model.layers[i].post_norm.weight.value = torch_to_numpy(weight)
|
||||
bias = hf_model.transformer.layers[
|
||||
i].post_attention_layernorm.bias.to(torch_type).detach()
|
||||
trt_model.layers[i].post_norm.bias.value = torch_to_numpy(bias)
|
||||
feed_weight_count += 2
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.encoder.layers[
|
||||
i].post_attention_layernorm.weight.to(torch_type).detach()
|
||||
trt_model.layers[i].post_norm.weight.value = torch_to_numpy(weight)
|
||||
feed_weight_count += 1
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.transformer.layers[
|
||||
i].post_attention_layernorm.weight.to(torch_type).detach()
|
||||
trt_model.layers[i].post_norm.weight.value = torch_to_numpy(weight)
|
||||
bias = hf_model.transformer.layers[
|
||||
i].post_attention_layernorm.bias.to(torch_type).detach()
|
||||
trt_model.layers[i].post_norm.bias.value = torch_to_numpy(bias)
|
||||
feed_weight_count += 2
|
||||
|
||||
# Multilayer perceptron h -> 4h weight
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.transformer.layers[i].mlp.dense_h_to_4h.weight.to(
|
||||
torch_type).detach()
|
||||
split_weight = torch.chunk(weight, mapping.tp_size,
|
||||
dim=0)[mapping.rank]
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.encoder.layers[
|
||||
i].mlp.dense_h_to_4h.weight.to(torch_type).detach()
|
||||
split_weight = torch.chunk(weight, 2 * mapping.tp_size, dim=0)
|
||||
# swap first and second half weight in columns to adapt trt_llm Swiglu
|
||||
split_weight = torch.cat(
|
||||
[
|
||||
split_weight[mapping.rank + mapping.tp_size],
|
||||
split_weight[mapping.rank],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.transformer.layers[i].mlp.dense_h_to_4h.weight.to(
|
||||
torch_type).detach()
|
||||
split_weight = torch.chunk(weight, mapping.tp_size,
|
||||
dim=0)[mapping.rank]
|
||||
|
||||
dst = trt_model.layers[i].mlp.fc
|
||||
if use_weight_only:
|
||||
load_quant_weight(
|
||||
src=split_weight,
|
||||
value_dst=dst.weight,
|
||||
scale_dst=dst.per_channel_scale,
|
||||
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
|
||||
else:
|
||||
dst.weight.value = torch_to_numpy(split_weight)
|
||||
feed_weight_count += 1
|
||||
|
||||
# Multilayer perceptron h -> 4h bias, only GLM-10B
|
||||
if model_name in ["glm_10b", "chatglm_6b"]:
|
||||
bias = hf_model.transformer.layers[i].mlp.dense_h_to_4h.bias.to(
|
||||
torch_type).detach()
|
||||
split_bias = split_qkv(bias, mapping.tp_size, mapping.tp_rank,
|
||||
hidden_size, num_heads, num_kv_heads)
|
||||
trt_model.layers[i].mlp.fc.bias.value = torch_to_numpy(split_bias)
|
||||
feed_weight_count += 1
|
||||
|
||||
# Multilayer perceptron 4h -> h weight
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight = hf_model.transformer.layers[i].mlp.dense_4h_to_h.weight.to(
|
||||
torch_type).detach()
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight = hf_model.transformer.encoder.layers[
|
||||
i].mlp.dense_4h_to_h.weight.to(torch_type).detach()
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight = hf_model.transformer.layers[i].mlp.dense_4h_to_h.weight.to(
|
||||
torch_type).detach()
|
||||
|
||||
split_weight = torch.chunk(weight, mapping.tp_size, dim=1)[mapping.rank]
|
||||
dst = trt_model.layers[i].mlp.proj
|
||||
if use_weight_only:
|
||||
load_quant_weight(
|
||||
src=split_weight,
|
||||
value_dst=dst.weight,
|
||||
scale_dst=dst.per_channel_scale,
|
||||
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
|
||||
else:
|
||||
dst.weight.value = np.ascontiguousarray(
|
||||
torch_to_numpy(split_weight))
|
||||
feed_weight_count += 1
|
||||
|
||||
# Multilayer perceptron 4h -> h bias, only GLM-10B
|
||||
if model_name in ["glm_10b", "chatglm_6b"]:
|
||||
bias = hf_model.transformer.layers[i].mlp.dense_4h_to_h.bias.to(
|
||||
torch_type).detach()
|
||||
split_bias = split_qkv(bias, mapping.tp_size, mapping.tp_rank,
|
||||
hidden_size, num_heads, num_kv_heads)
|
||||
trt_model.layers[i].mlp.proj.bias.value = torch_to_numpy(split_bias)
|
||||
feed_weight_count += 1
|
||||
|
||||
del hf_model
|
||||
tok = time.time()
|
||||
|
||||
# Final check
|
||||
if model_name in ["chatglm_6b"]:
|
||||
weight_count = 4 + num_layers * 9
|
||||
elif model_name in [
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
weight_count = 3 + num_layers * 7
|
||||
elif model_name in ["glm_10b"]:
|
||||
weight_count = 6 + num_layers * 12
|
||||
if feed_weight_count < weight_count:
|
||||
tensorrt_llm.logger.error("%d weights not loaded from HF" %
|
||||
(weight_count - feed_weight_count))
|
||||
return None
|
||||
tensorrt_llm.logger.info("Loading weights finish in %.2fs" % (tok - tik))
|
||||
return trt_model
|
||||
|
||||
|
||||
def get_scaling_factors(
|
||||
model_path: Union[str, Path],
|
||||
num_layers: int,
|
||||
quant_mode: Optional[QuantMode] = None,
|
||||
) -> Optional[Dict[str, List[int]]]:
|
||||
""" Get the scaling factors for Falcon model
|
||||
|
||||
Returns a dictionary of scaling factors for the selected layers of the
|
||||
Falcon model.
|
||||
|
||||
Args:
|
||||
model_path (str): Path to the quantized Falcon model
|
||||
layers (list): List of layers to get the scaling factors for. If None,
|
||||
all layers are selected.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary of scaling factors for the selected layers of the
|
||||
Falcon model.
|
||||
|
||||
example:
|
||||
|
||||
{
|
||||
'qkv_act': qkv_act_scale,
|
||||
'qkv_weights': qkv_weights_scale,
|
||||
'qkv_out' : qkv_outputs_scale,
|
||||
'dense_act': dense_act_scale,
|
||||
'dense_weights': dense_weights_scale,
|
||||
'fc_act': fc_act_scale,
|
||||
'fc_weights': fc_weights_scale,
|
||||
'proj_act': proj_act_scale,
|
||||
'proj_weights': proj_weights_scale,
|
||||
}
|
||||
"""
|
||||
|
||||
if model_path is None:
|
||||
logger.warning(f"--quantized_fp8_model_path not specified. "
|
||||
f"Initialize quantization scales automatically.")
|
||||
return get_dummy_quant_scales(num_layers)
|
||||
weight_dict = np.load(model_path)
|
||||
|
||||
# yapf: disable
|
||||
scaling_factor = {
|
||||
'qkv_act': [],
|
||||
'qkv_weights': [],
|
||||
'qkv_output': [],
|
||||
'dense_act': [],
|
||||
'dense_weights': [],
|
||||
'fc_act': [],
|
||||
'fc_weights': [],
|
||||
'proj_act': [],
|
||||
'proj_weights': [],
|
||||
}
|
||||
|
||||
for layer in range(num_layers):
|
||||
scaling_factor['qkv_act'].append(max(
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
|
||||
))
|
||||
scaling_factor['qkv_weights'].append(max(
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
|
||||
weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
|
||||
))
|
||||
if quant_mode is not None and quant_mode.has_fp8_kv_cache():
|
||||
# Not calibrarting KV cache.
|
||||
scaling_factor['qkv_output'].append(1.0)
|
||||
scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
|
||||
scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
|
||||
scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
|
||||
scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
|
||||
scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
|
||||
scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
|
||||
# yapf: enable
|
||||
for k, v in scaling_factor.items():
|
||||
assert len(v) == num_layers, \
|
||||
f'Expect scaling factor {k} of length {num_layers}, got {len(v)}'
|
||||
|
||||
return scaling_factor
|
||||
Reference in New Issue
Block a user