init
examples/api_client.py (Normal file, 77 lines)
@@ -0,0 +1,77 @@
"""Example Python client for vllm.entrypoints.api_server"""

import argparse
import json
from typing import Iterable, List

import requests


def clear_line(n: int = 1) -> None:
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for _ in range(n):
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      stream: bool = False) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "use_beam_search": True,
        "temperature": 0.0,
        "max_tokens": 16,
        "stream": stream,
    }
    response = requests.post(api_url, headers=headers, json=pload, stream=True)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> List[str]:
    data = json.loads(response.content)
    output = data["text"]
    return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--n", type=int, default=4)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, api_url, n, stream)

    if stream:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Beam candidate {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        for i, line in enumerate(output):
            print(f"Beam candidate {i}: {line!r}", flush=True)
examples/aqlm_example.py (Normal file, 46 lines)
@@ -0,0 +1,46 @@
import argparse

from vllm import LLM, SamplingParams


def main():

    parser = argparse.ArgumentParser(description='AQLM examples')

    parser.add_argument('--model',
                        '-m',
                        type=str,
                        default=None,
                        help='model path, as for HF')
    parser.add_argument('--choice',
                        '-c',
                        type=int,
                        default=0,
                        help='known good models by index, [0-4]')
    parser.add_argument('--tensor_parallel_size',
                        '-t',
                        type=int,
                        default=1,
                        help='tensor parallel size')

    args = parser.parse_args()

    models = [
        "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf",
        "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf",
        "ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf",
        "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf",
        "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf",
    ]

    model = LLM(args.model if args.model is not None else models[args.choice],
                tensor_parallel_size=args.tensor_parallel_size)

    sampling_params = SamplingParams(max_tokens=100, temperature=0)
    outputs = model.generate("Hello my name is",
                             sampling_params=sampling_params)
    print(outputs[0].outputs[0].text)


if __name__ == '__main__':
    main()
examples/fp8/README.md (Normal file, 96 lines)
@@ -0,0 +1,96 @@
# FP8 KV Cache

This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms.

## Prerequisites

- Python 3.x
- PyTorch
- NumPy
- Hugging Face Transformers
- Hugging Face Hub
- AMMO

Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps:
1. Install all necessary prerequisites and dependencies.
2. Convert HF model into a quantized HF model.
3. Extract KV Cache Scaling Factors from quantized HF model.
4. Load KV Cache Scaling Factors into vLLM.

### 2. Convert HF model into a quantized HF model.
Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md).

`quantize.py` (examples/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format).

The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/fp8/quantizer/README.md`.

### 3. Extract KV Cache Scaling Factors from quantized HF model.
`extract_scales.py` (examples/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model; however, at the moment this tool exclusively supports Llama 2 models. It is also important to note the following:
1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename.

2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM.

3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks.

```python
# prerequisites:
# - Quantized HF LLaMa 2 model
python3 examples/fp8/extract_scales.py --help
Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE]

KV Scale Extraction Example

Required arguments:
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU).
Optional arguments:
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
--revision: Specify the model's revision number. (Default: None)
--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None)
--output_name: Specify the output filename. (Default: kv_cache_scales.json)
--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None)
```
```python
Example:
python3 examples/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TENSOR_PARALLEL_SIZE> --output_dir <PATH_TO_OUTPUT_DIR>
```
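For reference, the JSON written by `extract_scales.py` follows the schema it serializes through `QuantParamSchema` (see `examples/fp8/extract_scales.py`): a model type plus a KV cache section holding the dtype and a per-TP-rank, per-layer scaling factor map. The sketch below is illustrative only; the layer count and scale values are made up, shown as a Python literal mirroring the JSON structure:

```python
# Illustrative shape of kv_cache_scales.json for a hypothetical TP-1 model
# with two decoder layers; the scale values are not real calibration output.
kv_cache_scales = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            "0": {          # TP rank
                "0": 0.021,  # layer index -> per-tensor scale
                "1": 0.019,
            },
        },
    },
}
```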
### 4. Load KV Cache Scaling Factors into vLLM.
This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow KV cache scaling factors to be utilized for FP8.
```python
# prerequisites:
# - LLaMa 2 kv_cache_scales.json file

python3 benchmarks/benchmark_throughput.py --help
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
                               [--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
                               [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
                               [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
                               [--quantization-param-path KV_CACHE_quantization_param_path]

Benchmark Throughput Example
optional arguments:
  -h, --help            show this help message and exit
  --backend {vllm,hf,mii}
  --dataset DATASET     Path to the dataset.
  --input-len INPUT_LEN Input prompt length for each request
  --output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
  --model MODEL
  --tokenizer TOKENIZER
  --quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None}
  --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
  --n N                 Number of generated sequences per prompt.
  --use-beam-search
  --num-prompts NUM_PROMPTS Number of prompts to process.
  --seed SEED
  --hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend.
  --trust-remote-code   trust remote code from huggingface
  --max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
  --dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
  --enforce-eager       enforce eager execution
  --kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
  --quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
```
```
Example:
python3 benchmarks/benchmark_throughput.py --input-len <INPUT_LEN> --output-len <OUTPUT_LEN> -tp <TENSOR_PARALLEL_SIZE> --kv-cache-dtype fp8 --quantization-param-path <path/to/kv_cache_scales.json> --model <path-to-llama2>
```
examples/fp8/extract_scales.py (Normal file, 367 lines)
@@ -0,0 +1,367 @@
import argparse
import glob
import json
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
from safetensors.torch import safe_open

from vllm.model_executor.layers.quantization.schema import QuantParamSchema


# Adapted from vllm/model_executor/model_loader/weight_utils.py
# The main differences are that we add the NPZ format and simplify
# its functionality drastically for our purposes (e.g. we assume that
# the quantized model exists locally and there is no need to download it)
def _prepare_hf_weights(
    quantized_model_dir: str,
    load_format: str = "auto",
    fall_back_to_pt: bool = True,
) -> Tuple[List[str], bool]:
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")
    use_safetensors = False
    # Some quantized models use .pt files for storing the weights.
    if load_format == "auto":
        allow_patterns = ["*.safetensors", "*.bin"]
    elif load_format == "safetensors":
        use_safetensors = True
        allow_patterns = ["*.safetensors"]
    elif load_format == "pt":
        allow_patterns = ["*.pt"]
    elif load_format == "npz":
        allow_patterns = ["*.npz"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")
    if fall_back_to_pt:
        allow_patterns += ["*.pt"]

    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(
            os.path.join(quantized_model_dir, pattern))
        if len(hf_weights_files) > 0:
            if pattern == "*.safetensors":
                use_safetensors = True
            break

    if not use_safetensors:
        # Exclude files that are not needed for inference.
        # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
        blacklist = [
            "training_args.bin",
            "optimizer.bin",
            "optimizer.pt",
            "scheduler.pt",
            "scaler.pt",
        ]
        hf_weights_files = [
            f for f in hf_weights_files
            if not any(f.endswith(x) for x in blacklist)
        ]

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{quantized_model_dir}`")

    return hf_weights_files, use_safetensors


# Adapted from vllm/model_executor/model_loader/weight_utils.py
def _hf_tensorfile_iterator(filename: str, load_format: str,
                            use_safetensors: bool):
    if load_format == "npz":
        assert not use_safetensors
        with np.load(filename) as data:
            for name in data.files:
                param = torch.from_numpy(data[name])
                yield name, param
    elif use_safetensors:
        with safe_open(filename, framework="pt") as f:
            for name in f.keys():  # NOQA: SIM118
                param = f.get_tensor(name)
                yield name, param
    else:
        state = torch.load(filename, map_location="cpu")
        for name, param in state.items():
            yield name, param
        del state
        torch.cuda.empty_cache()


def _kv_scales_extractor(
        hf_tensor_files: Iterable[str],
        use_safetensors: bool,
        rank_keyword: str = "rank",
        expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
    """
    Given a list of files containing tensor data, attempt to extract KV cache
    scales from these files. Intended as a helper function taking in the output
    from _prepare_hf_weights.
    Args:
    rank_keyword        Matches the number immediately after this keyword in the
                        tensor filename to determine the TP rank corresponding
                        to said tensor file
    expected_tp_size    If specified, the TP size of the tensor files is checked
                        against this and an error is raised if they don't match.
    Returns a dictionary mapping TP ranks to their relevant KV cache scales.
    The per-rank scales are themselves represented as a dictionary of layer
    indices to the respective per-layer scale.
    """
    for char in rank_keyword:
        assert not char.isdecimal(
        ), f"Rank keyword {rank_keyword} contains a numeric character!"
    rank_scales_map = {}
    for tensor_file in hf_tensor_files:
        try:
            rank_idx = tensor_file.find(rank_keyword)
            if rank_idx != -1:
                start_idx = rank_idx + len(rank_keyword)
                stop_idx = start_idx
                while stop_idx < len(
                        tensor_file) and tensor_file[stop_idx].isdecimal():
                    stop_idx += 1
                if stop_idx == start_idx:
                    raise RuntimeError("Did not find rank # in filename.")
                rank = int(tensor_file[start_idx:stop_idx])
            elif len(hf_tensor_files) == 1:
                # Since there is only one tensor file, we can assume
                # that it's intended for TP rank 0
                rank = 0
            else:
                raise RuntimeError(
                    f"Filename does not contain '{rank_keyword}'.")
        except RuntimeError:
            print("Unable to determine TP rank "
                  f"corresponding to file '{tensor_file}'")
            raise

        if rank not in rank_scales_map:
            layer_scales_map = {}
            rank_scales_map[rank] = layer_scales_map
        else:
            raise RuntimeError(
                f"Tensor file '{tensor_file}' shares TP rank {rank} "
                "with another tensor file.")

        # NOTE: this relies on the module-level `args` parsed in the
        # __main__ block below; load_format is not passed in explicitly.
        module_delimiter = ":" if args.load_format == "npz" else "."
        for name, param in _hf_tensorfile_iterator(tensor_file,
                                                   args.load_format,
                                                   use_safetensors):
            if "kv_cache_scaling_factor" in name:
                nums = [
                    int(s) for s in name.split(module_delimiter)
                    if s.isdecimal()
                ]
                assert len(
                    nums) == 1, f"Could not determine layer idx for {name}"
                layer_idx = nums[0]
                assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
                    f" factor corresponding to layer {layer_idx}"
                try:
                    layer_scales_map[layer_idx] = param.item()
                except RuntimeError:
                    print(
                        "This utility supports only per-tensor scalar scales "
                        f"for now. The tensor\n {name} = {param} \nis an "
                        "invalid scale factor.")
                    raise

    if all(
            len(layer_scales_map) == 0
            for layer_scales_map in rank_scales_map.values()):
        # Note: this is true even if the rank_scales_map is empty
        print("WARNING: No KV cache scale factors found. No output saved.")
        return None
    empirical_tp_world_size = max(rank_scales_map.keys()) + 1
    if expected_tp_size is not None:
        assert expected_tp_size == empirical_tp_world_size, \
            f"User expected TP world size = {expected_tp_size} " \
            "from model but tool is expecting TP world size = " \
            f"{empirical_tp_world_size} from model instead."
    for i in range(empirical_tp_world_size):
        assert i in rank_scales_map, "Expected TP world size = "\
            f"{empirical_tp_world_size} but did not find KV " \
            f"cache scaling factors for TP rank {i}"
    print(f"Found TP world size = {empirical_tp_world_size} "
          "when extracting KV cache scales!")
    return rank_scales_map


def _metadata_extractor(quantized_model_dir: str,
                        metadata_extract_fns: \
                            Dict[str, Callable[[Dict[str, Any]], Any]]) \
                                -> Dict[str, Any]:
    """
    Given a directory containing quantized model files, this function
    aims to extract metadata from the JSON files within this directory.
    Each JSON file is expected to represent a dictionary in JSON
    format (referred to as a "JSON-dictionary"). Metadata extraction is
    defined by a dictionary called metadata_extract_fns, where each
    metadata field name is mapped to an extraction function.

    These extraction functions are designed to take a JSON-dictionary
    as their only argument and return the corresponding metadata.
    While extraction functions are permitted to raise exceptions, they
    should only raise a KeyError or ValueError if the metadata field
    cannot be extracted from the current JSON-dictionary, yet there's
    a possibility of finding it in another JSON-dictionary.

    The function returns a dictionary that maps metadata fields to
    their extracted data. The keys of this dictionary correspond exactly
    to those in metadata_extract_fns. If any fields fail to be extracted,
    their corresponding values are set to None, and a warning is printed.
    """
    if not os.path.isdir(quantized_model_dir):
        raise FileNotFoundError(
            f"The quantized model directory `{quantized_model_dir}` "
            "does not exist.")
    metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))

    result = {}
    for file in metadata_files:
        with open(file) as f:
            try:
                metadata = json.load(f)
            except json.JSONDecodeError:
                print(f"Could not parse `{file}` as a valid metadata file,"
                      " skipping it.")
                continue
            if not isinstance(metadata, dict):
                print(f"The file `{file}` does not correspond to a "
                      "JSON-serialized dictionary, skipping it.")
                continue
            for metadata_name, extract_fn in metadata_extract_fns.items():
                try:
                    metadata_info = extract_fn(metadata)
                    if metadata_name not in result:
                        result[metadata_name] = metadata_info
                    elif metadata_info != result[metadata_name]:
                        raise RuntimeError(
                            "Metadata mismatch! Originally found "
                            f"{metadata_name} = {result[metadata_name]} but "
                            f"now found {metadata_name} = {metadata_info} in "
                            f"`{file}`")
                except KeyError:
                    # It is possible that a given file does not contain some
                    # of our selected metadata as it could be located in some
                    # other metadata file.
                    # 'EFINAE': extract_fn failure is not an error.
                    pass
                except ValueError:
                    # See above.
                    pass

    # Warn if we cannot find any of the requested metadata
    for metadata_name in metadata_extract_fns:
        if metadata_name not in result:
            print("WARNING: Unable to find requested metadata field "
                  f"`{metadata_name}`, setting it to None.")
            result[metadata_name] = None

    return result


def main(args):
    metadata_extract_fns = {
        "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
        "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
        "model_dtype": lambda json_dict: json_dict["dtype"]
    }
    recovered_metadata = _metadata_extractor(args.quantized_model,
                                             metadata_extract_fns)
    if args.tp_size is not None:
        metadata_tp_size = recovered_metadata["tp_size"]
        if metadata_tp_size is not None:
            assert args.tp_size == metadata_tp_size, \
                f"User expected TP world size = {args.tp_size} " \
                f"but found TP world size = {metadata_tp_size} from metadata!"
    expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
    rank_keyword = "rank"
    hf_tensor_files, use_safetensors = _prepare_hf_weights(
        args.quantized_model, args.load_format)
    rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
                                           rank_keyword, expected_tp_size)
    # Postprocess: formatting to the current schema. Consider pulling it
    # out into a dedicated function should it ever become more complicated.
    rank_scales_map = {
        rank: {k: scale[k]
               for k in sorted(scale.keys())}
        for rank, scale in rank_scales_map.items()
    }
    # TODO: Expand this with activation and weights scaling factors when
    # they are used in the future
    schema = QuantParamSchema(
        model_type=recovered_metadata["model_type"],
        kv_cache={
            "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
                      recovered_metadata["model_dtype"]),
            "scaling_factor":
            rank_scales_map
        },
    )

    if args.output_dir is None:
        output_file = os.path.join(args.quantized_model, args.output_name)
    else:
        if not os.path.isdir(args.output_dir):
            os.makedirs(args.output_dir, exist_ok=True)
        output_file = os.path.join(args.output_dir, args.output_name)

    with open(output_file, 'w') as f:
        f.write(schema.model_dump_json(indent=4))
        print(f"Completed! KV cache scaling factors saved to {output_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="This simple utility extracts the "
        "KV cache scaling factors from a quantized HF model "
        "and saves them to a JSON file compatible with later "
        "use by vLLM (pass this file to the appropriate "
        "runtime typically using the argument "
        "--quantization-param-path <filename>). This is only used "
        "if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
    parser.add_argument(
        "--quantized_model",
        help="Specify the directory containing a single quantized HF model. "
        "It is expected that the quantization format is FP8_E4M3, for use "
        "on ROCm (AMD GPU).",
        required=True)
    parser.add_argument(
        "--load_format",
        help="Optionally specify the format of the model's tensor files "
        "containing the KV cache scaling factors.",
        choices=["auto", "safetensors", "npz", "pt"],
        default="auto")
    parser.add_argument(
        "--output_dir",
        help="Optionally specify the output directory. By default the "
        "KV cache scaling factors will be saved in the model directory, "
        "however you can override this behavior here.",
        default=None)
    parser.add_argument(
        "--output_name",
        help="Optionally specify the output filename.",
        # TODO: Change this once additional scaling factors are enabled
        default="kv_cache_scales.json")
    parser.add_argument(
        "--tp_size",
        help="Optionally specify the tensor-parallel (TP) size that the "
        "quantized model should correspond to. If specified, during KV "
        "cache scaling factor extraction the observed TP size will be "
        "checked against this and an error will be raised if there is "
        "a mismatch. If not specified, the quantized model's expected "
        "TP size is instead inferred from the largest TP rank observed. "
        "The expected TP size is cross-checked against the TP ranks "
        "observed in the quantized model and an error is raised if any "
        "discrepancies are found.",
        default=None,
        type=int)
    args = parser.parse_args()

    main(args)
examples/fp8/quantizer/README.md (Normal file, 32 lines)
@@ -0,0 +1,32 @@
### Quantizer Utilities
`quantize.py`: NVIDIA Quantization utilities using AMMO, ported from TensorRT-LLM:
`https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py`

### Prerequisite

#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later
`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo`

#### AMMO Download (code and docs)
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz`
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz`

### Usage

#### Run on H100 system for speed if FP8; number of GPUs depends on the model size

#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache:
`python quantize.py --model_dir ./ll2-7b --dtype float16 --qformat fp8 --kv_cache_dtype fp8 --output_dir ./ll2_7b_fp8 --calib_size 512 --tp_size 1`

Outputs: model structure, quantized model & parameters (with scaling factors) are in JSON and Safetensors (npz is generated only for reference)
```
# ll ./ll2_7b_fp8/
total 19998244
drwxr-xr-x 2 root root        4096 Feb  7 01:08 ./
drwxrwxr-x 8 1060 1061        4096 Feb  7 01:08 ../
-rw-r--r-- 1 root root      176411 Feb  7 01:08 llama_tp1.json
-rw-r--r-- 1 root root 13477087480 Feb  7 01:09 llama_tp1_rank0.npz
-rw-r--r-- 1 root root  7000893272 Feb  7 01:08 rank0.safetensors
#
```
examples/fp8/quantizer/quantize.py (Normal file, 368 lines)
@@ -0,0 +1,368 @@
# Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.  # noqa: E501
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from examples/quantization/hf_ptq.py
"""

import argparse
import copy
import json
import random
import time

import ammo.torch.quantization as atq
import numpy as np
import torch
from ammo.torch.export import export_model_config
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

RAND_SEED = 1234
MAX_SEQ_LEN = 2048

EMPTY_CFG = {
    "quant_cfg": {
        "*weight_quantizer": {
            "enable": False,
        },
        "*input_quantizer": {
            "enable": False
        },
        "*lm_head*": {
            "enable": False
        },
        "*output_layer*": {
            "enable": False
        },
        "default": {
            "enable": False
        },
    },
    "algorithm": "max",
}

KV_CACHE_CFG = {
    "*.query_key_value.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.Wqkv.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.W_pack.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.c_attn.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.k_proj.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
    "*.v_proj.output_quantizer": {
        "num_bits": 8,
        "axis": None,
        "enable": True
    },
}

QUANT_CFG_CHOICES = {
    "int8_sq": atq.INT8_SMOOTHQUANT_CFG,
    "fp8": atq.FP8_DEFAULT_CFG,
    "int4_awq": atq.INT4_AWQ_CFG,
    "w4a8_awq": atq.W4A8_AWQ_BETA_CFG,
    "int8_wo": EMPTY_CFG,
    "int4_wo": EMPTY_CFG,
    "full_prec": EMPTY_CFG,
}

MODEL_NAME_PATTERN_MAP = {
    "GPT2": "gpt2",
    "Xverse": "llama",
    "Llama": "llama",
    "Mistral": "llama",
    "GPTJ": "gptj",
    "FalconForCausalLM": "falcon",
    "RWForCausalLM": "falcon",
    "baichuan": "baichuan",
    "MPT": "mpt",
    "Bloom": "bloom",
    "ChatGLM": "chatglm",
    "QWen": "qwen",
}


def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
    print(f"Initializing tokenizer from {ckpt_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        ckpt_path,
        model_max_length=max_seq_len,
        padding_side="left",
        trust_remote_code=True,
    )
    if model_type and model_type == "qwen":
        # qwen use token id 151643 as pad and eos tokens
        tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
        tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)

    # can't set attribute 'pad_token' for "<unk>"
    if tokenizer.pad_token != "<unk>":
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    assert (tokenizer.pad_token
            is not None), f"Pad token for {model_type} cannot be set!"

    return tokenizer


def get_model(ckpt_path, dtype="fp16", device="cuda"):
    print(f"Initializing model from {ckpt_path}")
    if dtype == "bf16" or dtype == "bfloat16":
        dtype = torch.bfloat16
    elif dtype == "fp16" or dtype == "float16":
        dtype = torch.float16
    elif dtype == "fp32" or dtype == "float32":
        dtype = torch.float32
    else:
        raise NotImplementedError(f"Unknown dtype {dtype}")

    # model_kwargs = {"torch_dtype": dtype}
    model_kwargs = {"torch_dtype": "auto"}

    model = AutoModelForCausalLM.from_pretrained(ckpt_path,
                                                 device_map="auto",
                                                 **model_kwargs,
                                                 trust_remote_code=True)
    model.eval()

    model_dtype = next(model.parameters()).dtype
    if dtype != model_dtype:
        print("[TensorRT-LLM][WARNING] The manually set model data type is "
              f"{dtype}, but the data type of the HuggingFace model is "
              f"{model_dtype}.")

    return model


def get_model_type(model):
    for k, v in MODEL_NAME_PATTERN_MAP.items():
        if k.lower() in type(model).__name__.lower():
            return v
    return None


def get_calib_dataloader(data="cnn_dailymail",
                         tokenizer=None,
                         batch_size=1,
                         calib_size=512,
                         block_size=512,
                         device=None):
    print("Loading calibration dataset")
    if data == "pileval":
        dataset = load_dataset(
            "json",
            data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
            split="train")
        dataset = dataset["text"][:calib_size]
    elif data == "cnn_dailymail":
        dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
        dataset = dataset["article"][:calib_size]
    else:
        raise NotImplementedError

    batch_encoded = tokenizer.batch_encode_plus(dataset,
                                                return_tensors="pt",
                                                padding="max_length",
                                                truncation=True,
                                                max_length=block_size)
    if device:
        batch_encoded = batch_encoded.to(device)
    batch_encoded = batch_encoded["input_ids"]

    calib_dataloader = DataLoader(batch_encoded,
                                  batch_size=batch_size,
                                  shuffle=False)

    return calib_dataloader


def quantize_model(model, quant_cfg, calib_dataloader=None):

    def calibrate_loop():
        if calib_dataloader is None:
            return
        """Adjusts weights and scaling factors based on selected algorithms."""
        for idx, data in enumerate(calib_dataloader):
            print(f"Calibrating batch {idx}")
            model(data)

    print("Starting quantization...")
    start_time = time.time()
    atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
    end_time = time.time()
    print("Quantization done. Total time used: {:.2f} s.".format(end_time -
                                                                 start_time))

    return model


def main(args):
    if not torch.cuda.is_available():
        raise EnvironmentError("GPU is required for inference.")

    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)

    model = get_model(args.model_dir, args.dtype, args.device)
    model_type = get_model_type(model)
    tokenizer = get_tokenizer(args.model_dir, model_type=model_type)

    if args.qformat in ["full_prec", "int8_wo", "int4_wo"
                        ] and args.kv_cache_dtype is None:
        print(f"No quantization applied, export {args.dtype} model")
    else:
        if "awq" in args.qformat:
            if args.calib_size > 32:
                print("AWQ calibration could take longer with calib_size = "
                      f"{args.calib_size}, Using calib_size=32 instead")
                args.calib_size = 32
            print("\nAWQ calibration could take longer than other calibration "
                  "methods. Please increase the batch size to speed up the "
                  "calibration process. Batch size can be set by adding the "
                  "argument --batch_size <batch_size> to the command line.\n")

        calib_dataloader = get_calib_dataloader(
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            calib_size=args.calib_size,
            device=args.device,
        )

        if args.qformat in QUANT_CFG_CHOICES:
            quant_cfg = QUANT_CFG_CHOICES[args.qformat]
        else:
            raise ValueError(
                f"Unsupported quantization format: {args.qformat}")

        if "awq" in args.qformat:
            quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
            weight_quantizer = quant_cfg["quant_cfg"][
                "*weight_quantizer"]  # type: ignore
            if isinstance(weight_quantizer, list):
                weight_quantizer = weight_quantizer[0]
            weight_quantizer["block_sizes"][-1] = args.awq_block_size

        if args.kv_cache_dtype is not None:
            if args.kv_cache_dtype == "fp8":
                for value in KV_CACHE_CFG.values():
                    value.update({"num_bits": (4, 3)})  # type: ignore
            quant_cfg["quant_cfg"].update(KV_CACHE_CFG)  # type: ignore

        print(quant_cfg)

        model = quantize_model(model, quant_cfg, calib_dataloader)

    with torch.inference_mode():
        if model_type is None:
            print(f"Unknown model type {type(model).__name__}. Continue "
                  "exporting...")
            model_type = f"unknown:{type(model).__name__}"

        export_path = args.output_dir
        start_time = time.time()

        if args.qformat == "int4_awq" and model_type == "qwen":
            torch.save(model.state_dict(), export_path)
        else:
            export_npz = (model_type not in [
                'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
            ])

            # export safetensors
            export_model_config(
                model,
                model_type,
                getattr(torch, args.dtype),
                export_dir=export_path,
                inference_tensor_parallel=args.tp_size,
                inference_pipeline_parallel=args.pp_size,
                # export_tensorrt_llm_config=(not export_npz),
                export_tensorrt_llm_config=False,
                export_npz=export_npz)

            # Workaround for wo quantization
            if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
                with open(f"{export_path}/config.json", 'r') as f:
                    tensorrt_llm_config = json.load(f)
                if args.qformat == "int8_wo":
                    tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
                elif args.qformat == "int4_wo":
                    tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16'
                else:
                    tensorrt_llm_config["quantization"]["quant_algo"] = None
                with open(f"{export_path}/config.json", "w") as f:
                    json.dump(tensorrt_llm_config, f, indent=4)

        end_time = time.time()
        print("Quantized model exported to {} \nTotal time used {:.2f} s.".
              format(export_path, end_time - start_time))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model_dir",
                        help="Specify where the HuggingFace model is",
                        required=True)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dtype", help="Model data type.", default="float16")
    parser.add_argument(
        "--qformat",
        help="Quantization format.",
        default="full_prec",
        choices=[
            "fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
            "full_prec"
        ],
    )
    parser.add_argument("--batch_size",
                        help="Batch size for calibration.",
                        type=int,
                        default=1)
    parser.add_argument("--calib_size",
                        help="Number of samples for calibration.",
                        type=int,
                        default=512)
    parser.add_argument("--output_dir", default="exported_model")
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--awq_block_size", type=int, default=128)
    parser.add_argument("--kv_cache_dtype",
                        help="KV Cache dtype.",
                        default=None,
                        choices=["int8", "fp8", None])
    args = parser.parse_args()

    main(args)
examples/gradio_openai_chatbot_webserver.py (Normal file, 82 lines)
@@ -0,0 +1,82 @@
import argparse

import gradio as gr
from openai import OpenAI

# Argument parser setup
parser = argparse.ArgumentParser(
    description='Chatbot Interface with Customizable Parameters')
parser.add_argument('--model-url',
                    type=str,
                    default='http://localhost:8000/v1',
                    help='Model URL')
parser.add_argument('-m',
                    '--model',
                    type=str,
                    required=True,
                    help='Model name for the chatbot')
parser.add_argument('--temp',
                    type=float,
                    default=0.8,
                    help='Temperature for text generation')
parser.add_argument('--stop-token-ids',
                    type=str,
                    default='',
                    help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)

# Parse the arguments
args = parser.parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def predict(message, history):
    # Convert chat history to OpenAI format
    history_openai_format = [{
        "role": "system",
        "content": "You are a great ai assistant."
    }]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({
            "role": "assistant",
            "content": assistant
        })
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # Chat history
        temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        extra_body={
            'repetition_penalty':
            1,
            'stop_token_ids': [
                int(id.strip()) for id in args.stop_token_ids.split(',')
                if id.strip()
            ] if args.stop_token_ids else []
        })

    # Read and return generated text from response stream
    partial_message = ""
    for chunk in stream:
        partial_message += (chunk.choices[0].delta.content or "")
        yield partial_message


# Create and launch a chat interface with Gradio
gr.ChatInterface(predict).queue().launch(server_name=args.host,
                                         server_port=args.port,
                                         share=True)
examples/gradio_webserver.py (Normal file, 52 lines)
@@ -0,0 +1,52 @@
import argparse
import json

import gradio as gr
import requests


def http_bot(prompt):
    headers = {"User-Agent": "vLLM Client"}
    pload = {
        "prompt": prompt,
        "stream": True,
        "max_tokens": 128,
    }
    response = requests.post(args.model_url,
                             headers=headers,
                             json=pload,
                             stream=True)

    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"][0]
            yield output


def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# vLLM text completion demo\n")
        inputbox = gr.Textbox(label="Input",
                              placeholder="Enter text and press ENTER")
        outputbox = gr.Textbox(label="Output",
                               placeholder="Generated result from the model")
        inputbox.submit(http_bot, [inputbox], [outputbox])
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--model-url",
                        type=str,
                        default="http://localhost:8000/generate")
    args = parser.parse_args()

    demo = build_demo()
    demo.queue().launch(server_name=args.host,
                        server_port=args.port,
                        share=True)
examples/llava_example.py (Normal file, 90 lines)
@@ -0,0 +1,90 @@
import argparse
import os
import subprocess

import torch

from vllm import LLM
from vllm.sequence import MultiModalData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.


def run_llava_pixel_values():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    images = torch.load("images/stop_sign_pixel_values.pt")

    outputs = llm.generate(prompt,
                           multi_modal_data=MultiModalData(
                               type=MultiModalData.Type.IMAGE, data=images))
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def run_llava_image_features():
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="image_features",
        image_token_id=32000,
        image_input_shape="1,576,1024",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    images = torch.load("images/stop_sign_image_features.pt")

    outputs = llm.generate(prompt,
                           multi_modal_data=MultiModalData(
                               type=MultiModalData.Type.IMAGE, data=images))
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def main(args):
    if args.type == "pixel_values":
        run_llava_pixel_values()
    else:
        run_llava_image_features()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Demo on Llava")
    parser.add_argument("--type",
                        type=str,
                        choices=["pixel_values", "image_features"],
                        default="pixel_values",
                        help="image input type")
    args = parser.parse_args()
    # Download from s3
    s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
    local_directory = "images"

    # Make sure the local directory exists or create it
    os.makedirs(local_directory, exist_ok=True)

    # Use AWS CLI to sync the directory, assume anonymous access
    subprocess.check_call([
        "aws",
        "s3",
        "sync",
        s3_bucket_path,
        local_directory,
        "--no-sign-request",
    ])
    main(args)
examples/llm_engine_example.py (Normal file, 62 lines)
@@ -0,0 +1,62 @@
import argparse
from typing import List, Tuple

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams


def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
    """Create a list of test prompts with their sampling parameters."""
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1)),
        ("To be or not to be,",
         SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
         SamplingParams(n=2,
                        best_of=5,
                        temperature=0.8,
                        top_p=0.95,
                        frequency_penalty=0.1)),
        ("It is only with the heart that one can see rightly",
         SamplingParams(n=3, best_of=3, use_beam_search=True,
                        temperature=0.0)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine(args: argparse.Namespace) -> LLMEngine:
    """Initialize the LLMEngine from the command line arguments."""
    engine_args = EngineArgs.from_cli_args(args)
    return LLMEngine.from_engine_args(engine_args)


def main(args: argparse.Namespace):
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine(args)
    test_prompts = create_test_prompts()
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Demo on using the LLMEngine class directly')
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    main(args)
examples/logging_configuration.md (Normal file, 178 lines)
@@ -0,0 +1,178 @@
# Logging Configuration

vLLM leverages Python's `logging.config.dictConfig` functionality to enable
robust and flexible configuration of the various loggers used by vLLM.

vLLM offers two environment variables that accommodate logging configurations
ranging from simple and inflexible to more complex and more flexible.

- No vLLM logging (simple and inflexible)
  - Set `VLLM_CONFIGURE_LOGGING=0` (leaving `VLLM_LOGGING_CONFIG_PATH` unset)
- vLLM's default logging configuration (simple and inflexible)
  - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1`
- Fine-grained custom logging configuration (more complex, more flexible)
  - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` and
    set `VLLM_LOGGING_CONFIG_PATH=<path-to-logging-config.json>`


## Logging Configuration Environment Variables

### `VLLM_CONFIGURE_LOGGING`

`VLLM_CONFIGURE_LOGGING` controls whether or not vLLM takes any action to
configure the loggers used by vLLM. This functionality is enabled by default,
but can be disabled by setting `VLLM_CONFIGURE_LOGGING=0` when running vLLM.

If `VLLM_CONFIGURE_LOGGING` is enabled and no value is given for
`VLLM_LOGGING_CONFIG_PATH`, vLLM will use built-in default configuration to
configure the root vLLM logger. By default, no other vLLM loggers are
configured and, as such, all vLLM loggers defer to the root vLLM logger to make
all logging decisions.

If `VLLM_CONFIGURE_LOGGING` is disabled and a value is given for
`VLLM_LOGGING_CONFIG_PATH`, an error will occur while starting vLLM.

### `VLLM_LOGGING_CONFIG_PATH`

`VLLM_LOGGING_CONFIG_PATH` allows users to specify a path to a JSON file of
alternative, custom logging configuration that will be used instead of vLLM's
built-in default logging configuration. The logging configuration should be
provided in JSON format following the schema specified by Python's [logging
configuration dictionary
schema](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details).

If `VLLM_LOGGING_CONFIG_PATH` is specified, but `VLLM_CONFIGURE_LOGGING` is
disabled, an error will occur while starting vLLM.
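Conceptually, the interaction between these two variables reduces to: reject the
contradictory combination, and otherwise either apply the user-supplied
dictConfig JSON or fall back to the built-in default. The snippet below is an
illustrative sketch of the behavior described above, not vLLM's actual
implementation:

```python
import json
import logging.config
import os

# Illustrative sketch of the semantics described above; not vLLM's actual code.
configure_logging = os.getenv("VLLM_CONFIGURE_LOGGING", "1") != "0"
config_path = os.getenv("VLLM_LOGGING_CONFIG_PATH")

if config_path and not configure_logging:
    # Giving VLLM_LOGGING_CONFIG_PATH while VLLM_CONFIGURE_LOGGING=0 is an error.
    raise RuntimeError(
        "VLLM_LOGGING_CONFIG_PATH given, but VLLM_CONFIGURE_LOGGING is disabled")

if configure_logging:
    if config_path:
        # Custom configuration: parse the JSON file and hand it to dictConfig.
        with open(config_path) as f:
            logging.config.dictConfig(json.load(f))
    else:
        # No path given: vLLM falls back to its built-in default configuration
        # for the root "vllm" logger (omitted here).
        pass
```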
## Examples
|
||||
|
||||
### Example 1: Customize vLLM root logger
|
||||
|
||||
For this example, we will customize the vLLM root logger to use
|
||||
[`python-json-logger`](https://github.com/madzak/python-json-logger) to log to
|
||||
STDOUT of the console in JSON format with a log level of `INFO`.
|
||||
|
||||
To begin, first, create an appropriate JSON logging configuration file:
|
||||
|
||||
**/path/to/logging_config.json:**
|
||||
|
||||
```json
|
||||
{
|
||||
"formatters": {
|
||||
"json": {
|
||||
"class": "pythonjsonlogger.jsonlogger.JsonFormatter"
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class" : "logging.StreamHandler",
|
||||
"formatter": "json",
|
||||
"level": "INFO",
|
||||
"stream": "ext://sys.stdout"
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
"vllm": {
|
||||
"handlers": ["console"],
|
||||
"level": "INFO",
|
||||
"propagate": false
|
||||
}
|
||||
},
|
||||
"version": 1
|
||||
}
|
||||
```
|
||||
|
||||
Next, install the `python-json-logger` package if it's not already installed:
|
||||
|
||||
```bash
|
||||
pip install python-json-logger
|
||||
```
|
||||
|
||||
Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set
|
||||
to the path of the custom logging configuration JSON file:
|
||||
|
||||
```bash
|
||||
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
|
||||
python3 -m vllm.entrypoints.openai.api_server \
|
||||
--max-model-len 2048 \
|
||||
--model mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
|
||||
### Example 2: Silence a particular vLLM logger
|
||||
|
||||
To silence a particular vLLM logger, it is necessary to provide custom logging
|
||||
configuration for the target logger that configures the logger so that it won't
|
||||
propagate its log messages to the root vLLM logger.
|
||||
|
||||
When custom configuration is provided for any logger, it is also necessary to
|
||||
provide configuration for the root vLLM logger since any custom logger
|
||||
configuration overrides the built-in default logging configuration used by vLLM.
|
||||
|
||||
First, create an appropriate JSON logging configuration file that includes
|
||||
configuration for the root vLLM logger and for the logger you wish to silence:
|
||||
|
||||
**/path/to/logging_config.json:**
|
||||
|
||||
```json
|
||||
{
|
||||
"formatters": {
|
||||
"vllm": {
|
||||
"class": "vllm.logging.NewLineFormatter",
|
||||
"datefmt": "%m-%d %H:%M:%S",
|
||||
"format": "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"vllm": {
|
||||
"class" : "logging.StreamHandler",
|
||||
"formatter": "vllm",
|
||||
"level": "INFO",
|
||||
"stream": "ext://sys.stdout"
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
"vllm": {
|
||||
"handlers": ["vllm"],
|
||||
"level": "DEBUG",
|
||||
"propagage": false
|
||||
},
|
||||
"vllm.example_noisy_logger": {
|
||||
"propagate": false
|
||||
}
|
||||
},
|
||||
"version": 1
|
||||
}
|
||||
```
|
||||
|
||||
Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set
|
||||
to the path of the custom logging configuration JSON file:
|
||||
|
||||
```bash
|
||||
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
|
||||
python3 -m vllm.entrypoints.openai.api_server \
|
||||
--max-model-len 2048 \
|
||||
--model mistralai/Mistral-7B-v0.1
|
||||
```


### Example 3: Disable vLLM default logging configuration

To disable vLLM's default logging configuration and silence all vLLM loggers,
simply set `VLLM_CONFIGURE_LOGGING=0` when running vLLM. This prevents vLLM
from configuring the root vLLM logger, which in turn silences all other vLLM
loggers.

```bash
VLLM_CONFIGURE_LOGGING=0 \
    python3 -m vllm.entrypoints.openai.api_server \
    --max-model-len 2048 \
    --model mistralai/Mistral-7B-v0.1
```


## Additional resources

- [`logging.config` Dictionary Schema Details](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details)
124
examples/multilora_inference.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
This example shows how to use the multi-LoRA functionality
|
||||
for offline inference.
|
||||
|
||||
Requires HuggingFace credentials for access to Llama2.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def create_test_prompts(
|
||||
lora_path: str
|
||||
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
||||
"""Create a list of test prompts with their sampling parameters.
|
||||
|
||||
2 requests for base model, 4 requests for the LoRA. We define 2
|
||||
different LoRA adapters (using the same model for demo purposes).
|
||||
Since we also set `max_loras=1`, the expectation is that the requests
|
||||
with the second LoRA adapter will be run after all requests with the
|
||||
first adapter have finished.
|
||||
"""
|
||||
return [
|
||||
("A robot may not injure a human being",
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128), None),
|
||||
("To be or not to be,",
|
||||
SamplingParams(temperature=0.8,
|
||||
top_k=5,
|
||||
presence_penalty=0.2,
|
||||
max_tokens=128), None),
|
||||
(
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128,
|
||||
stop_token_ids=[32003]),
|
||||
LoRARequest("sql-lora", 1, lora_path)),
|
||||
(
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
|
||||
SamplingParams(n=3,
|
||||
best_of=3,
|
||||
use_beam_search=True,
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
stop_token_ids=[32003]),
|
||||
LoRARequest("sql-lora", 1, lora_path)),
|
||||
(
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
|
||||
SamplingParams(temperature=0.0,
|
||||
logprobs=1,
|
||||
prompt_logprobs=1,
|
||||
max_tokens=128,
|
||||
stop_token_ids=[32003]),
|
||||
LoRARequest("sql-lora2", 2, lora_path)),
|
||||
(
|
||||
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
|
||||
SamplingParams(n=3,
|
||||
best_of=3,
|
||||
use_beam_search=True,
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
stop_token_ids=[32003]),
|
||||
LoRARequest("sql-lora", 1, lora_path)),
|
||||
]
|
||||
|
||||
|
||||
def process_requests(engine: LLMEngine,
|
||||
test_prompts: List[Tuple[str, SamplingParams,
|
||||
Optional[LoRARequest]]]):
|
||||
"""Continuously process a list of prompts and handle the outputs."""
|
||||
request_id = 0
|
||||
|
||||
while test_prompts or engine.has_unfinished_requests():
|
||||
if test_prompts:
|
||||
prompt, sampling_params, lora_request = test_prompts.pop(0)
|
||||
engine.add_request(str(request_id),
|
||||
prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request)
|
||||
request_id += 1
|
||||
|
||||
request_outputs: List[RequestOutput] = engine.step()
|
||||
|
||||
for request_output in request_outputs:
|
||||
if request_output.finished:
|
||||
print(request_output)
|
||||
|
||||
|
||||
def initialize_engine() -> LLMEngine:
|
||||
"""Initialize the LLMEngine."""
|
||||
# max_loras: controls the number of LoRAs that can be used in the same
|
||||
# batch. Larger numbers will cause higher memory usage, as each LoRA
|
||||
# slot requires its own preallocated tensor.
|
||||
# max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
|
||||
# numbers will cause higher memory usage. If you know that all LoRAs will
|
||||
# use the same rank, it is recommended to set this as low as possible.
|
||||
# max_cpu_loras: controls the size of the CPU LoRA cache.
|
||||
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
|
||||
enable_lora=True,
|
||||
max_loras=1,
|
||||
max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_num_seqs=256)
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function that sets up and runs the prompt processing."""
|
||||
engine = initialize_engine()
|
||||
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||
test_prompts = create_test_prompts(lora_path)
|
||||
process_requests(engine, test_prompts)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
22
examples/offline_inference.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
72
examples/offline_inference_distributed.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
This example shows how to use Ray Data for running offline batch inference
|
||||
distributed across a multi-node cluster.
|
||||
|
||||
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
|
||||
# Create a class to do batch inference.
|
||||
class LLMPredictor:
|
||||
|
||||
def __init__(self):
|
||||
# Create an LLM.
|
||||
self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
|
||||
|
||||
def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects that contain the prompt,
|
||||
# generated text, and other information.
|
||||
outputs = self.llm.generate(batch["text"], sampling_params)
|
||||
prompt = []
|
||||
generated_text = []
|
||||
for output in outputs:
|
||||
prompt.append(output.prompt)
|
||||
generated_text.append(' '.join([o.text for o in output.outputs]))
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"generated_text": generated_text,
|
||||
}
|
||||
|
||||
|
||||
# Read one text file from S3. Ray Data supports reading multiple files
|
||||
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
|
||||
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
|
||||
|
||||
# Apply batch inference for all input data.
|
||||
ds = ds.map_batches(
|
||||
LLMPredictor,
|
||||
# Set the concurrency to the number of LLM instances.
|
||||
concurrency=10,
|
||||
# Specify the number of GPUs required per LLM instance.
|
||||
# NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
|
||||
# (i.e., `tensor_parallel_size`).
|
||||
num_gpus=1,
|
||||
# Specify the batch size for inference.
|
||||
batch_size=32,
|
||||
)
|
||||
|
||||
# Peek first 10 results.
|
||||
# NOTE: This is for local testing and debugging. For production use cases,
# one should write the full result out as shown below.
|
||||
outputs = ds.take(limit=10)
|
||||
for output in outputs:
|
||||
prompt = output["prompt"]
|
||||
generated_text = output["generated_text"]
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# Write inference output data out as Parquet files to S3.
|
||||
# Multiple files would be written to the output destination,
|
||||
# and each task would write one or more files separately.
|
||||
#
|
||||
# ds.write_parquet("s3://<your-output-bucket>")
|
||||
36
examples/offline_inference_neuron.py
Executable file
@@ -0,0 +1,36 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
max_num_seqs=8,
|
||||
# The max_model_len and block_size arguments are required to be the same as
# the max sequence length when targeting the neuron device.
|
||||
# Currently, this is a known limitation in continuous batching support
|
||||
# in transformers-neuronx.
|
||||
# TODO(liangfu): Support paged-attention in transformers-neuronx.
|
||||
max_model_len=128,
|
||||
block_size=128,
|
||||
# The device can be automatically detected when AWS Neuron SDK is installed.
|
||||
# The device argument can be either unspecified for automated detection,
|
||||
# or explicitly assigned.
|
||||
device="neuron",
|
||||
tensor_parallel_size=2)
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
53
examples/offline_inference_with_prefix.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
prefix = (
|
||||
"You are an expert school principal, skilled in effectively managing "
|
||||
"faculty and staff. Draft 10-15 questions for a potential first grade "
|
||||
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
|
||||
"community, joyful discovery, and life-long learning. The candidate is "
|
||||
"coming in for a first-round panel interview for a 8th grade Math "
|
||||
"teaching role. They have 5 years of previous teaching experience "
|
||||
"as an assistant teacher at a co-ed, public school with experience "
|
||||
"in middle school math teaching. Based on these information, fulfill "
|
||||
"the following paragraph: ")
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.0)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
|
||||
|
||||
generating_prompts = [prefix + prompt for prompt in prompts]
|
||||
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(generating_prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
print("-" * 80)
|
||||
|
||||
# The llm.generate call will batch all prompts and send the batch at once
|
||||
# if resources allow. The prefix will only be cached after the first batch
|
||||
# is processed, so we need to call generate once to calculate the prefix
|
||||
# and cache it.
|
||||
outputs = llm.generate(generating_prompts[0], sampling_params)
|
||||
|
||||
# Subsequent batches can leverage the cached prefix
|
||||
outputs = llm.generate(generating_prompts, sampling_params)
|
||||
|
||||
# Print the outputs. You should see the same outputs as before
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
36
examples/openai_chat_completion_client.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}],
|
||||
model=model,
|
||||
)
|
||||
|
||||
print("Chat completion results:")
|
||||
print(chat_completion)
|
||||
31
examples/openai_completion_client.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
# Completion API
|
||||
stream = False
|
||||
completion = client.completions.create(
|
||||
model=model,
|
||||
prompt="A robot may not injure a human being",
|
||||
echo=False,
|
||||
n=2,
|
||||
stream=stream,
|
||||
logprobs=3)
|
||||
|
||||
print("Completion results:")
|
||||
if stream:
|
||||
for c in completion:
|
||||
print(c)
|
||||
else:
|
||||
print(completion)
|
||||
54
examples/production_monitoring/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# vLLM + Prometheus/Grafana
|
||||
|
||||
This is a simple example that shows you how to connect vLLM metric logging to the Prometheus/Grafana stack. For this example, we launch Prometheus and Grafana via Docker. You can check out other deployment methods on the [Prometheus](https://prometheus.io/) and [Grafana](https://grafana.com/) websites.
|
||||
|
||||
Install:
|
||||
- [`docker`](https://docs.docker.com/engine/install/)
|
||||
- [`docker compose`](https://docs.docker.com/compose/install/linux/#install-using-the-repository)
|
||||
|
||||
### Launch
|
||||
|
||||
Prometheus metric logging is enabled by default in the OpenAI-compatible server. Launch via the entrypoint:
|
||||
```bash
|
||||
python3 -m vllm.entrypoints.openai.api_server \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--max-model-len 2048 \
|
||||
--disable-log-requests
|
||||
```
|
||||
|
||||
Launch Prometheus and Grafana servers with `docker compose`:
|
||||
```bash
|
||||
docker compose up
|
||||
```
|
||||
|
||||
Submit some sample requests to the server:
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
python3 ../../benchmarks/benchmark_serving.py \
|
||||
--model mistralai/Mistral-7B-v0.1 \
|
||||
--tokenizer mistralai/Mistral-7B-v0.1 \
|
||||
--endpoint /v1/completions \
|
||||
--dataset ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--request-rate 3.0
|
||||
```
|
||||
|
||||
Navigating to [`http://localhost:8000/metrics`](http://localhost:8000/metrics) will show the raw Prometheus metrics being exposed by vLLM.
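
For a quick look from the command line (any HTTP client works; `curl` is shown
here), the same endpoint can be fetched directly:

```bash
curl -s http://localhost:8000/metrics | head -n 20
```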
|
||||
|
||||
### Grafana Dashboard
|
||||
|
||||
Navigate to [`http://localhost:3000`](http://localhost:3000). Log in with the default username (`admin`) and password (`admin`).
|
||||
|
||||
#### Add Prometheus Data Source
|
||||
|
||||
Navigate to [`http://localhost:3000/connections/datasources/new`](http://localhost:3000/connections/datasources/new) and select Prometheus.
|
||||
|
||||
On the Prometheus configuration page, add the `Prometheus Server URL` under `Connection`. For this setup, Grafana and Prometheus run in separate containers, but Docker creates a DNS name for each container, so you can simply use `http://prometheus:9090`.
|
||||
|
||||
Click `Save & Test`. You should get a green check saying "Successfully queried the Prometheus API.".
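
Alternatively, the data source can be provisioned from a file instead of through
the UI. A minimal sketch (the file name is arbitrary; it would need to be mounted
into the Grafana container under `/etc/grafana/provisioning/datasources/`):

```yaml
# prometheus-datasource.yaml -- example Grafana provisioning file
apiVersion: 1
datasources:
  - name: prometheus
    type: prometheus
    url: http://prometheus:9090
    isDefault: true
```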
|
||||
|
||||
#### Import Dashboard
|
||||
|
||||
Navigate to [`http://localhost:3000/dashboard/import`](http://localhost:3000/dashboard/import), upload `grafana.json`, and select the `prometheus` datasource. You should see a screen that looks like the following:
|
||||
|
||||

|
||||
19
examples/production_monitoring/docker-compose.yaml
Normal file
@@ -0,0 +1,19 @@
|
||||
# docker-compose.yaml
|
||||
version: "3"
|
||||
|
||||
services:
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway" # allow a direct connection from container to the local machine
|
||||
ports:
|
||||
- "9090:9090" # the default port used by Prometheus
|
||||
volumes:
|
||||
- ${PWD}/prometheus.yaml:/etc/prometheus/prometheus.yml # mount Prometheus config file
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
depends_on:
|
||||
- prometheus
|
||||
ports:
|
||||
- "3000:3000" # the default port used by Grafana
|
||||
1206
examples/production_monitoring/grafana.json
Normal file
File diff suppressed because it is too large
10
examples/production_monitoring/prometheus.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
# prometheus.yaml
|
||||
global:
|
||||
scrape_interval: 5s
|
||||
evaluation_interval: 30s
|
||||
|
||||
scrape_configs:
|
||||
- job_name: vllm
|
||||
static_configs:
|
||||
- targets:
|
||||
- 'host.docker.internal:8000'
|
||||
29
examples/template_alpaca.jinja
Normal file
@@ -0,0 +1,29 @@
|
||||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
### Instruction:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
### Response:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'user_context' %}
|
||||
### Input:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
|
||||
### Response:
|
||||
{% endif %}
|
||||
13
examples/template_baichuan.jinja
Normal file
@@ -0,0 +1,13 @@
|
||||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '<reserved_106>' + message['content'] -}}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{{- '<reserved_107>' + message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
|
||||
{{- '<reserved_107>' -}}
|
||||
{% endif %}
|
||||
18
examples/template_chatglm.jinja
Normal file
@@ -0,0 +1,18 @@
|
||||
{%- set counter = namespace(index=0) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '[Round ' + counter.index|string + ']\n问:' + message['content'] -}}
|
||||
{%- set counter.index = counter.index + 1 -%}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' -%}
|
||||
{{- '\n答:' + message['content'] -}}
|
||||
{%- if (loop.last and add_generation_prompt) or not loop.last -%}
|
||||
{{- '\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
|
||||
{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
|
||||
{{- '\n答:' -}}
|
||||
{%- endif -%}
|
||||
18
examples/template_chatglm2.jinja
Normal file
@@ -0,0 +1,18 @@
|
||||
{%- set counter = namespace(index=1) -%}
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- '[Round ' + counter.index|string + ']\n\n问:' + message['content'] -}}
|
||||
{%- set counter.index = counter.index + 1 -%}
|
||||
{%- endif -%}
|
||||
{%- if message['role'] == 'assistant' -%}
|
||||
{{- '\n\n答:' + message['content'] -}}
|
||||
{%- if (loop.last and add_generation_prompt) or not loop.last -%}
|
||||
{{- '\n\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
|
||||
{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
|
||||
{{- '\n\n答:' -}}
|
||||
{%- endif -%}
|
||||
2
examples/template_chatml.jinja
Normal file
@@ -0,0 +1,2 @@
|
||||
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
|
||||
15
examples/template_falcon.jinja
Normal file
@@ -0,0 +1,15 @@
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'user' -%}
|
||||
{{- 'User: ' + message['content'] -}}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{{- 'Assistant: ' + message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- if (loop.last and add_generation_prompt) or not loop.last -%}
|
||||
{{- '\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
|
||||
{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
|
||||
{{- 'Assistant:' -}}
|
||||
{% endif %}
|
||||
17
examples/template_falcon_180b.jinja
Normal file
@@ -0,0 +1,17 @@
|
||||
{%- for message in messages -%}
|
||||
{%- if message['role'] == 'system' -%}
|
||||
{{- 'System: ' + message['content'] -}}
|
||||
{%- elif message['role'] == 'user' -%}
|
||||
{{- 'User: ' + message['content'] -}}
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{{- 'Falcon: ' + message['content'] -}}
|
||||
{%- endif -%}
|
||||
{%- if (loop.last and add_generation_prompt) or not loop.last -%}
|
||||
{{- '\n' -}}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
|
||||
{%- if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
|
||||
{{- 'Falcon:' -}}
|
||||
{% endif %}
|
||||
30
examples/template_inkbot.jinja
Normal file
@@ -0,0 +1,30 @@
|
||||
<#meta#>
|
||||
- Date: {{ (messages|selectattr('role', 'equalto', 'meta-current_date')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-current_date')|list) else '' }}
|
||||
- Task: {{ (messages|selectattr('role', 'equalto', 'meta-task_name')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'meta-task_name')|list) else '' }}
|
||||
<#system#>
|
||||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
<#chat#>
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
<#user#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
<#bot#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'user_context' %}
|
||||
<#user_context#>
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
|
||||
<#bot#>
|
||||
{% endif %}
|
||||
282
examples/tensorize_vllm_model.py
Normal file
@@ -0,0 +1,282 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
|
||||
TensorSerializer, stream_io)
|
||||
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from vllm.distributed import initialize_model_parallel
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
"""
|
||||
tensorize_vllm_model.py is a script that can be used to serialize and
|
||||
deserialize vLLM models. These models can be loaded using tensorizer
|
||||
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
|
||||
or locally. Tensor encryption and decryption is also supported, although
|
||||
libsodium must be installed to use it. Install vllm with tensorizer support
|
||||
using `pip install vllm[tensorizer]`.
|
||||
|
||||
To serialize a model, install vLLM from source, then run something
|
||||
like this from the root level of this repository:
|
||||
|
||||
python -m examples.tensorize_vllm_model \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
serialize \
|
||||
--serialized-directory s3://my-bucket/ \
|
||||
--suffix vllm
|
||||
|
||||
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
||||
and saves it to your S3 bucket. A local directory can also be used. This
|
||||
assumes your S3 credentials are specified as environment variables
|
||||
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT_URL`.
|
||||
To provide S3 credentials directly, you can provide `--s3-access-key-id` and
|
||||
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
|
||||
script.
|
||||
|
||||
You can also encrypt the model weights with a randomly-generated key by
|
||||
providing a `--keyfile` argument.
|
||||
|
||||
To deserialize a model, you can run something like this from the root
|
||||
level of this repository:
|
||||
|
||||
python -m examples.tensorize_vllm_model \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
deserialize \
|
||||
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
|
||||
|
||||
Which downloads the model tensors from your S3 bucket and deserializes them.
|
||||
|
||||
You can also provide a `--keyfile` argument to decrypt the model weights if
|
||||
they were serialized with encryption.
|
||||
|
||||
For more information on the available arguments for serializing, run
|
||||
`python -m examples.tensorize_vllm_model serialize --help`.
|
||||
|
||||
Or for deserializing:
|
||||
|
||||
`python -m examples.tensorize_vllm_model deserialize --help`.
|
||||
|
||||
Once a model is serialized, it can be used to load the model when running the
|
||||
OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing
|
||||
the `--tensorizer-uri` CLI argument that is functionally the same as the
|
||||
`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to
|
||||
signify that the model to be deserialized is a vLLM model, rather than a
|
||||
HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer
|
||||
in the same inference server, albeit without the speed optimizations. To
|
||||
deserialize an encrypted file, the `--encryption-keyfile` argument can be used
|
||||
to provide the path to the keyfile used to encrypt the model weights. For
|
||||
information on all the arguments that can be used to configure tensorizer's
|
||||
deserialization, check out the tensorizer options argument group in the
|
||||
`vllm/entrypoints/openai/api_server.py` script with `--help`.
|
||||
|
||||
Tensorizer can also be invoked with the `LLM` class directly to load models:
|
||||
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
load_format="tensorizer",
|
||||
tensorizer_uri=path_to_opt_tensors,
|
||||
num_readers=3,
|
||||
vllm_tensorized=True)
|
||||
"""
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="An example script that can be used to serialize and "
|
||||
"deserialize vLLM models. These models "
|
||||
"can be loaded using tensorizer directly to the GPU "
|
||||
"extremely quickly. Tensor encryption and decryption is "
|
||||
"also supported, although libsodium must be installed to "
|
||||
"use it.")
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
subparsers = parser.add_subparsers(dest='command')
|
||||
|
||||
serialize_parser = subparsers.add_parser(
|
||||
'serialize', help="Serialize a model to `--serialized-directory`")
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--suffix",
|
||||
type=str,
|
||||
required=False,
|
||||
help=(
|
||||
"The suffix to append to the serialized model directory, which is "
|
||||
"used to construct the location of the serialized model tensors, "
|
||||
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
|
||||
"`--suffix` is `v1`, the serialized model tensors will be "
|
||||
"saved to "
|
||||
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
|
||||
"If none is provided, a random UUID will be used."))
|
||||
serialize_parser.add_argument(
|
||||
"--serialized-directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The directory to serialize the model to. "
|
||||
"This can be a local directory or S3 URI. The path to where the "
|
||||
"tensors are saved is a combination of the supplied `dir` and model "
|
||||
"reference ID. For instance, if `dir` is the serialized directory, "
|
||||
"and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
|
||||
"be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
|
||||
"where `suffix` is given by `--suffix` or a random UUID if not "
|
||||
"provided.")
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
required=False,
|
||||
help=("Encrypt the model weights with a randomly-generated binary key,"
|
||||
" and save the key at this path"))
|
||||
|
||||
deserialize_parser = subparsers.add_parser(
|
||||
'deserialize',
|
||||
help=("Deserialize a model from `--path-to-tensors`"
|
||||
" to verify it can be loaded and used."))
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--path-to-tensors",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The local path or S3 URI to the model tensors to deserialize. ")
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
required=False,
|
||||
help=("Path to a binary key to use to decrypt the model weights,"
|
||||
" if the model was serialized with encryption"))
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_model_contiguous(model):
|
||||
# Ensure tensors are saved in memory contiguously
|
||||
for param in model.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
|
||||
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return model_cls
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def serialize():
|
||||
|
||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
||||
dataclasses.fields(EngineArgs)}
|
||||
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
model = (engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
|
||||
encryption_params = EncryptionParams.random() if keyfile else None
|
||||
if keyfile:
|
||||
with _write_stream(keyfile) as stream:
|
||||
stream.write(encryption_params.key)
|
||||
|
||||
with _write_stream(model_path) as stream:
|
||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||
serializer.write_module(model)
|
||||
serializer.close()
|
||||
|
||||
print("Serialization complete. Model tensors saved to", model_path)
|
||||
if keyfile:
|
||||
print("Key saved to", keyfile)
|
||||
|
||||
|
||||
def deserialize():
|
||||
config = AutoConfig.from_pretrained(model_ref)
|
||||
|
||||
with no_init_or_tensor():
|
||||
model_class = _get_vllm_model_architecture(config)
|
||||
model = model_class(config)
|
||||
|
||||
before_mem = get_mem_usage()
|
||||
start = time.time()
|
||||
|
||||
if keyfile:
|
||||
with _read_stream(keyfile) as stream:
|
||||
key = stream.read()
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
tensorizer_args.deserializer_params['encryption'] = \
|
||||
decryption_params
|
||||
|
||||
with (_read_stream(model_path)) as stream, TensorDeserializer(
|
||||
stream, **tensorizer_args.deserializer_params) as deserializer:
|
||||
deserializer.load_into_module(model)
|
||||
end = time.time()
|
||||
|
||||
# Brag about how fast we are.
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
print(
|
||||
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
|
||||
)
|
||||
print(f"Memory usage before: {before_mem}")
|
||||
print(f"Memory usage after: {after_mem}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
args = parse_args()
|
||||
|
||||
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
|
||||
or None)
|
||||
s3_secret_access_key = (args.s3_secret_access_key
|
||||
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
|
||||
|
||||
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
|
||||
|
||||
_read_stream, _write_stream = (partial(
|
||||
stream_io.open_stream,
|
||||
mode=mode,
|
||||
s3_access_key_id=s3_access_key_id,
|
||||
s3_secret_access_key=s3_secret_access_key,
|
||||
s3_endpoint=s3_endpoint,
|
||||
) for mode in ("rb", "wb+"))
|
||||
|
||||
model_ref = args.model
|
||||
|
||||
model_name = model_ref.split("/")[1]
|
||||
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "8080"
|
||||
|
||||
torch.distributed.init_process_group(world_size=1, rank=0)
|
||||
initialize_model_parallel()
|
||||
|
||||
keyfile = args.keyfile if args.keyfile else None
|
||||
|
||||
if args.command == "serialize":
|
||||
input_dir = args.serialized_directory.rstrip('/')
|
||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
||||
model_path = f"{base_path}/model.tensors"
|
||||
serialize()
|
||||
elif args.command == "deserialize":
|
||||
tensorizer_args = TensorizerArgs.from_cli_args(args)
|
||||
model_path = args.path_to_tensors
|
||||
deserialize()
|
||||
else:
|
||||
raise ValueError("Either serialize or deserialize must be specified.")
|
||||