# GPT-J

This document describes how to build and run the GPT-J model on a single XPU with Kunlunxin XTRT-LLM.

## Overview

The XTRT-LLM GPT-J example code lives in `examples/gptj`. The folder contains the following main files:

- `build.py` builds the XTRT engine(s) needed to run the GPT-J model
- `run.py` runs inference on an input prompt

## Support Matrix

- FP16

## Usage

### 1. Download the weights from Hugging Face

```bash
# 1. Weights & config
git clone https://huggingface.co/EleutherAI/gpt-j-6b ./downloads/gptj-6b
pushd ./downloads/gptj-6b && \
  rm -f pytorch_model.bin && \
  wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/pytorch_model.bin && \
popd

# 2. Vocab and merge table
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/vocab.json
wget https://huggingface.co/EleutherAI/gpt-j-6b/resolve/main/merges.txt
```
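After step 1, the model directory should hold the checkpoint together with the tokenizer and config files that came with the clone. A minimal sanity check, as a sketch (the helper `missing_files` and the exact file list are assumptions based on the commands above, not part of the example code):

```python
from pathlib import Path

# Hypothetical sanity check: confirm the files step 1 should have produced.
MODEL_DIR = Path("./downloads/gptj-6b")
EXPECTED = ["config.json", "pytorch_model.bin", "vocab.json", "merges.txt"]


def missing_files(model_dir: Path = MODEL_DIR) -> list[str]:
    """Return the names from EXPECTED that are absent from model_dir."""
    return [name for name in EXPECTED if not (model_dir / name).exists()]


# An empty list means the download step looks complete.
print(missing_files())
```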

### 2. Build the XTRT engine

XTRT-LLM builds the XTRT engine from the HF checkpoint. If no checkpoint directory is specified, XTRT-LLM builds the engine with dummy weights instead, which is useful for performance testing.

Example build invocations:

```bash
# Build a float16 engine using HF weights.
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.

python3 build.py --dtype=float16 \
                 --log_level=verbose \
                 --enable_context_fmha \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16 \
                 --max_batch_size=32 \
                 --max_input_len=1919 \
                 --max_output_len=128 \
                 --output_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
                 --model_dir=./downloads/gptj-6b 2>&1 | tee build.log
```

```bash
# Build a float16 engine using dummy weights, useful for performance tests.
# Enable several XTRT-LLM plugins to increase runtime performance. It also helps with build time.

python3 build.py --dtype=float16 \
                 --log_level=verbose \
                 --enable_context_fmha \
                 --use_gpt_attention_plugin float16 \
                 --use_gemm_plugin float16 \
                 --max_batch_size=32 \
                 --max_input_len=1919 \
                 --max_output_len=128 \
                 --output_dir=./downloads/gptj-6b/trt_engines/gptj_engine_dummy_weights 2>&1 | tee build.log
```
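The `--max_input_len` and `--max_output_len` flags fix the engine's runtime shape limits at build time, and their sum must fit within GPT-J-6B's 2048-position context window (which is why the examples above use 1919 + 128 = 2047). A small sketch of that budget check (`check_shape_budget` is a hypothetical helper, not part of `build.py`):

```python
# GPT-J-6B (EleutherAI/gpt-j-6b) has n_positions = 2048.
GPTJ_MAX_POSITIONS = 2048


def check_shape_budget(max_input_len: int, max_output_len: int) -> int:
    """Return the unused position budget; raise if the limits overflow the context."""
    total = max_input_len + max_output_len
    if total > GPTJ_MAX_POSITIONS:
        raise ValueError(
            f"{total} tokens exceed the {GPTJ_MAX_POSITIONS}-position context window"
        )
    return GPTJ_MAX_POSITIONS - total


# The flags used above: 1919 input + 128 output = 2047, one position to spare.
print(check_shape_budget(1919, 128))  # → 1
```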

### 3. Run

To run the XTRT-LLM GPT-J model, execute:

```bash
python3 run.py --max_output_len=50 \
    --engine_dir=./downloads/gptj-6b/trt_engines/fp16/1-XPU/ \
    --hf_model_location=./downloads/gptj-6b
```
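Each `run.py` invocation handles one configuration; when sweeping a parameter such as `--max_output_len`, it can be convenient to assemble the command programmatically. A sketch assuming only the CLI flags shown above (`run_cmd` is a hypothetical helper):

```python
import subprocess

ENGINE_DIR = "./downloads/gptj-6b/trt_engines/fp16/1-XPU/"
MODEL_DIR = "./downloads/gptj-6b"


def run_cmd(max_output_len: int) -> list[str]:
    """Assemble the run.py invocation shown above as an argv list."""
    return [
        "python3", "run.py",
        f"--max_output_len={max_output_len}",
        f"--engine_dir={ENGINE_DIR}",
        f"--hf_model_location={MODEL_DIR}",
    ]


# subprocess.run(run_cmd(50), check=True)  # uncomment to launch on a machine with an XPU
```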