Sync from v0.13
This commit is contained in:
60
examples/others/lmcache/README.md
Normal file
60
examples/others/lmcache/README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# LMCache Examples
|
||||
|
||||
This folder demonstrates how to use LMCache for disaggregated prefilling, CPU offloading and KV cache sharing.
|
||||
|
||||
## 1. Disaggregated Prefill in vLLM v1
|
||||
|
||||
This example demonstrates how to run LMCache with disaggregated prefill using NIXL on a single node.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Install [LMCache](https://github.com/LMCache/LMCache). You can simply run `pip install lmcache`.
|
||||
- Install [NIXL](https://github.com/ai-dynamo/nixl).
|
||||
- At least 2 GPUs
|
||||
- Valid Hugging Face token (HF_TOKEN) for Llama 3.1 8B Instruct.
|
||||
|
||||
### Usage
|
||||
|
||||
Run
|
||||
`cd disagg_prefill_lmcache_v1`
|
||||
to get into `disagg_prefill_lmcache_v1` folder, and then run
|
||||
|
||||
```bash
|
||||
bash disagg_example_nixl.sh
|
||||
```
|
||||
|
||||
to run disaggregated prefill and benchmark the performance.
|
||||
|
||||
### Components
|
||||
|
||||
#### Server Scripts
|
||||
|
||||
- `disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh` - Launches individual vLLM servers for prefill/decode, and also launches the proxy server.
|
||||
- `disagg_prefill_lmcache_v1/disagg_proxy_server.py` - FastAPI proxy server that coordinates between prefiller and decoder
|
||||
- `disagg_prefill_lmcache_v1/disagg_example_nixl.sh` - Main script to run the example
|
||||
|
||||
#### Configuration
|
||||
|
||||
- `disagg_prefill_lmcache_v1/configs/lmcache-prefiller-config.yaml` - Configuration for prefiller server
|
||||
- `disagg_prefill_lmcache_v1/configs/lmcache-decoder-config.yaml` - Configuration for decoder server
|
||||
|
||||
#### Log Files
|
||||
|
||||
The main script generates several log files:
|
||||
|
||||
- `prefiller.log` - Logs from the prefill server
|
||||
- `decoder.log` - Logs from the decode server
|
||||
- `proxy.log` - Logs from the proxy server
|
||||
|
||||
## 2. CPU Offload Examples
|
||||
|
||||
- `python cpu_offload_lmcache.py -v v0` - CPU offloading implementation for vLLM v0
|
||||
- `python cpu_offload_lmcache.py -v v1` - CPU offloading implementation for vLLM v1
|
||||
|
||||
## 3. KV Cache Sharing
|
||||
|
||||
The `kv_cache_sharing_lmcache_v1.py` example demonstrates how to share KV caches between vLLM v1 instances.
|
||||
|
||||
## 4. Disaggregated Prefill in vLLM v0
|
||||
|
||||
The `disaggregated_prefill_lmcache_v0.py` provides an example of how to run disaggregated prefill in vLLM v0.
|
||||
134
examples/others/lmcache/cpu_offload_lmcache.py
Normal file
134
examples/others/lmcache/cpu_offload_lmcache.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates the example usage of cpu offloading
|
||||
with LMCache in vLLM v1 or v0.
|
||||
|
||||
Usage:
|
||||
|
||||
Specify vLLM version
|
||||
|
||||
-v v0 : Use LMCacheConnector
|
||||
model = mistralai/Mistral-7B-Instruct-v0.2
|
||||
(Includes enable_chunked_prefill = True)
|
||||
|
||||
-v v1 : Use LMCacheConnectorV1 (default)
|
||||
model = meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
(Without enable_chunked_prefill)
|
||||
|
||||
Note that `lmcache` is needed to run this example.
|
||||
Requirements:
|
||||
https://docs.lmcache.ai/getting_started/installation.html#prerequisites
|
||||
Learn more about LMCache environment setup, please refer to:
|
||||
https://docs.lmcache.ai/getting_started/installation.html
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
|
||||
from lmcache.integration.vllm.utils import ENGINE_NAME
|
||||
from lmcache.v1.cache_engine import LMCacheEngineBuilder
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
|
||||
def setup_environment_variables():
    """Configure LMCache through its environment variables.

    Enables experimental features, sets a 256-token chunk size, turns on the
    local CPU offloading backend, and caps it at 5.0 GB. Must run before the
    engine is constructed.
    """
    lmcache_settings = {
        # Opt in to experimental LMCache features.
        "LMCACHE_USE_EXPERIMENTAL": "True",
        # Split KV caches into 256-token chunks.
        "LMCACHE_CHUNK_SIZE": "256",
        # Enable the local CPU offloading backend.
        "LMCACHE_LOCAL_CPU": "True",
        # Limit local CPU memory usage to 5.0 GB.
        "LMCACHE_MAX_LOCAL_CPU_SIZE": "5.0",
    }
    os.environ.update(lmcache_settings)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def build_llm_with_lmcache(lmcache_connector: str, model: str):
    """Yield an ``LLM`` wired to LMCache; tear the cache engine down on exit.

    Args:
        lmcache_connector: KV-transfer connector name ("LMCacheConnector" for
            vLLM v0, "LMCacheConnectorV1" for v1).
        model: Hugging Face model identifier to serve.
    """
    transfer_config = KVTransferConfig(
        kv_connector=lmcache_connector,
        kv_role="kv_both",
    )
    # gpu_memory_utilization=0.8 targets a 40GB A40; reduce the value if your
    # GPU has less memory.
    # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
    engine_args = EngineArgs(
        model=model,
        kv_transfer_config=transfer_config,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
    )

    llm_instance = LLM(**asdict(engine_args))
    try:
        yield llm_instance
    finally:
        # Always clean up the LMCache backend, even if the caller raised.
        LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||
|
||||
|
||||
def print_output(
    llm: LLM,
    prompt: list[str],
    sampling_params: SamplingParams,
    req_str: str,
):
    """Generate completions for ``prompt`` and print them with elapsed time.

    Successful KV-cache stores appear in the logs, e.g.
    `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0`,
    which indicates the KV cache has been stored in LMCache.
    """
    separator = "-" * 50
    started_at = time.time()
    results = llm.generate(prompt, sampling_params)
    print(separator)
    for result in results:
        print(f"Generated text: {result.outputs[0].text!r}")
    elapsed = time.time() - started_at
    print(f"Generation took {elapsed:.2f} seconds, {req_str} request done.")
    print(separator)
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options for this example.

    Returns:
        argparse.Namespace with a single ``version`` attribute ("v0" or "v1").
    """
    cli = argparse.ArgumentParser()
    supported_versions = ["v0", "v1"]
    cli.add_argument(
        "-v",
        "--version",
        choices=supported_versions,
        default="v1",
        help="Specify vLLM version (default: v1)",
    )
    return cli.parse_args()
|
||||
|
||||
|
||||
def main():
    """Run the CPU-offload demo for the vLLM version selected via ``-v``.

    Two requests sharing a long prefix are issued; the second should hit the
    KV cache that LMCache offloaded to CPU during the first.
    """
    # BUG FIX: parse_args() was defined (and documented in the README as
    # `python cpu_offload_lmcache.py -v v0`) but never invoked, so the flag
    # was silently ignored and the v1 path was always taken.
    args = parse_args()

    if args.version == "v0":
        # vLLM v0 path, per the module docstring.
        lmcache_connector = "LMCacheConnector"
        model = "mistralai/Mistral-7B-Instruct-v0.2"
    else:
        # Default vLLM v1 path.
        lmcache_connector = "LMCacheConnectorV1"
        model = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    setup_environment_variables()

    with build_llm_with_lmcache(lmcache_connector, model) as llm:
        # This example script runs two requests with a shared prefix.
        # Define the shared prompt and specific prompts.
        shared_prompt = "Hello, how are you?" * 1000
        first_prompt = [
            shared_prompt + "Hello, my name is",
        ]
        second_prompt = [
            shared_prompt + "Tell me a very long story",
        ]

        sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

        # First request populates the CPU-offloaded KV cache.
        print_output(llm, first_prompt, sampling_params, "first")

        time.sleep(1)

        # Second request should reuse the cached shared prefix.
        print_output(llm, second_prompt, sampling_params, "second")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
144
examples/others/lmcache/disagg_prefill_lmcache_v0.py
Normal file
144
examples/others/lmcache/disagg_prefill_lmcache_v0.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates the example usage of disaggregated prefilling
|
||||
with LMCache.
|
||||
We will launch 2 vllm instances (GPU 0 for prefill and GPU 1 for decode),
|
||||
and launch an additional LMCache server.
|
||||
KV cache is transferred in the following manner:
|
||||
vLLM prefill node -> LMCache server -> vLLM decode node.
|
||||
|
||||
Note that `pip install lmcache` is needed to run this example.
|
||||
Learn more about LMCache in https://github.com/LMCache/LMCache.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from multiprocessing import Event, Process
|
||||
|
||||
from lmcache.experimental.cache_engine import LMCacheEngineBuilder
|
||||
from lmcache.integration.vllm.utils import ENGINE_NAME
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
# LMCache-related environment variables
|
||||
# The port to start LMCache server
|
||||
port = 8100
|
||||
# Use experimental features in LMCache
|
||||
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
|
||||
# LMCache is set to use 256 tokens per chunk
|
||||
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
|
||||
# Disable local CPU backend in LMCache
|
||||
os.environ["LMCACHE_LOCAL_CPU"] = "False"
|
||||
# Set local CPU memory buffer limit to 5.0 GB
|
||||
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
|
||||
# Set the remote URL for LMCache server
|
||||
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
|
||||
# Set the serializer/deserializer between vllm and LMCache server
|
||||
# `naive` indicates using raw bytes of the tensor without any compression
|
||||
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you?" * 1000,
|
||||
]
|
||||
|
||||
|
||||
def run_prefill(prefill_done, prompts):
    """Prefill worker: compute KV caches on GPU 0 and push them to LMCache.

    Sets ``prefill_done`` once generation completes so the decode worker can
    begin retrieving the transferred KV cache.
    """
    # Pin this worker to GPU 0.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # max_tokens=1: only the prefill pass matters here, not the completion.
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    transfer_config = KVTransferConfig(
        kv_connector="LMCacheConnector",
        kv_role="kv_producer",
        kv_rank=0,
        kv_parallel_size=2,
    )
    # gpu_memory_utilization=0.8 targets a 40GB A40; reduce the value if your
    # GPU has less memory.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        kv_transfer_config=transfer_config,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    for completion in llm.generate(prompts, sampling_params):
        print(f"Generated text: {completion.outputs[0].text!r}")
    print("Prefill node is finished.")
    prefill_done.set()

    # Clean up lmcache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||
|
||||
|
||||
def run_decode(prefill_done, prompts, timeout=1):
    """Decode worker: wait for the prefill node, then generate on GPU 1.

    Args:
        prefill_done: event set by the prefill worker when its pass is done.
        prompts: prompts whose KV cache was produced by the prefill node.
        timeout: extra seconds to wait after the event fires, giving the
            LMCache server time to settle.
    """
    # Pin this worker to GPU 1.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    transfer_config = KVTransferConfig(
        kv_connector="LMCacheConnector",
        kv_role="kv_consumer",
        kv_rank=1,
        kv_parallel_size=2,
    )
    # gpu_memory_utilization=0.8 targets a 40GB A40; reduce the value if your
    # GPU has less memory.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        kv_transfer_config=transfer_config,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    print("Waiting for prefill node to finish...")
    prefill_done.wait()
    time.sleep(timeout)

    for completion in llm.generate(prompts, sampling_params):
        print(f"Generated text: {completion.outputs[0].text!r}")

    # Clean up lmcache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||
|
||||
|
||||
def run_lmcache_server(port):
    """Launch the standalone LMCache server as a subprocess.

    Args:
        port: TCP port the server should listen on (host is localhost).

    Returns:
        The ``subprocess.Popen`` handle; the caller is responsible for
        terminating and reaping it.
    """
    import sys  # local import: keeps this example's top-level imports unchanged

    # Use sys.executable rather than a bare "python" so the server runs under
    # the same interpreter/virtualenv as this script ("python" may be missing
    # or resolve to a different environment without lmcache installed).
    server_proc = subprocess.Popen(
        [sys.executable, "-m", "lmcache.experimental.server", "localhost", str(port)]
    )
    return server_proc
|
||||
|
||||
|
||||
def main():
    """Orchestrate the disaggregated-prefill demo.

    Starts the LMCache server plus one prefill and one decode process, waits
    for the decode side to finish, then tears everything down.
    """
    prefill_done = Event()
    prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
    decode_process = Process(target=run_decode, args=(prefill_done, prompts))
    lmcache_server_process = run_lmcache_server(port)

    try:
        # Start prefill node
        prefill_process.start()

        # Start decode node
        decode_process.start()

        # The decode node finishes last; once it is done the demo is complete.
        decode_process.join()
    finally:
        # Clean up the processes.  Join after terminate so no zombie
        # processes are left behind (previously the prefill process was
        # terminated but never reaped).
        prefill_process.terminate()
        prefill_process.join()
        lmcache_server_process.terminate()
        lmcache_server_process.wait()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,13 @@
|
||||
local_cpu: False
|
||||
max_local_cpu_size: 0
|
||||
#local_disk:
|
||||
max_local_disk_size: 0
|
||||
remote_serde: NULL
|
||||
|
||||
enable_nixl: True
|
||||
nixl_role: "receiver"
|
||||
nixl_peer_host: "localhost"
|
||||
nixl_peer_port: 55555
|
||||
nixl_buffer_size: 1073741824 # 1GB
|
||||
nixl_buffer_device: "cuda"
|
||||
nixl_enable_gc: True
|
||||
@@ -0,0 +1,13 @@
|
||||
local_cpu: False
|
||||
max_local_cpu_size: 0
|
||||
#local_disk:
|
||||
max_local_disk_size: 0
|
||||
remote_serde: NULL
|
||||
|
||||
enable_nixl: True
|
||||
nixl_role: "sender"
|
||||
nixl_peer_host: "localhost"
|
||||
nixl_peer_port: 55555
|
||||
nixl_buffer_size: 1073741824 # 1GB
|
||||
nixl_buffer_device: "cuda"
|
||||
nixl_enable_gc: True
|
||||
@@ -0,0 +1,142 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "Warning: LMCache disaggregated prefill support for vLLM v1 is experimental and subject to change."
|
||||
|
||||
|
||||
PIDS=()
|
||||
|
||||
# Switch to the directory of the current script
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")"
|
||||
|
||||
# Abort unless HF_TOKEN is set and looks like a real Hugging Face token.
check_hf_token() {
    if [[ -z "$HF_TOKEN" ]]; then
        echo "HF_TOKEN is not set. Please set it to your Hugging Face token."
        exit 1
    elif [[ "$HF_TOKEN" != hf_* ]]; then
        echo "HF_TOKEN is not a valid Hugging Face token. Please set it to your Hugging Face token."
        exit 1
    fi
    echo "HF_TOKEN is set and valid."
}
|
||||
|
||||
# Verify that at least 2 GPUs are visible (NVIDIA via nvidia-smi, AMD via
# rocm-smi); exit otherwise.  Disaggregated prefill needs one GPU for the
# prefiller and one for the decoder.
check_num_gpus() {
    # `command -v` is the POSIX-portable replacement for `which`; testing the
    # command directly also avoids the fragile `cmd; if [ $? -ne 0 ]` pattern.
    if command -v rocm-smi > /dev/null 2>&1; then
        num_gpus=$(rocm-smi --showid | grep -c Instinct)
    else
        num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
    fi

    if [ "$num_gpus" -lt 2 ]; then
        echo "You need at least 2 GPUs to run disaggregated prefill."
        exit 1
    else
        echo "Found $num_gpus GPUs."
    fi
}
|
||||
|
||||
# Fail fast, with installation guidance, when the python package named by $1
# cannot be imported.
ensure_python_library_installed() {
    echo "Checking if $1 is installed..."
    if python3 -c "import $1" > /dev/null 2>&1; then
        echo "$1 is installed."
        return
    fi
    # nixl is not on PyPI under that name, so point at its repo instead.
    if [ "$1" == "nixl" ]; then
        echo "$1 is not installed. Please refer to https://github.com/ai-dynamo/nixl for installation."
    else
        echo "$1 is not installed. Please install it via pip install $1."
    fi
    exit 1
}
|
||||
|
||||
# Terminate every child this script started and exit cleanly.
# Invoked from the INT/USR1/TERM traps and at the end of main().
cleanup() {
    echo "Stopping everything…"
    trap - INT TERM # prevent re-entrancy: the kill below must not re-trigger cleanup
    kill -- -$$ # negative PID == “this whole process-group” (vLLM servers, proxy, tees)
    wait # reap children so we don't leave zombies
    exit 0
}
|
||||
|
||||
# Poll localhost:$1 until the completions endpoint answers, giving up after
# 20 minutes (model loading can be slow).  Returns 0 on success, 1 on timeout.
wait_for_server() {
    local port=$1
    local timeout_seconds=1200
    local deadline=$(( $(date +%s) + timeout_seconds ))

    echo "Waiting for server on port $port..."

    until curl -s "localhost:${port}/v1/completions" > /dev/null; do
        if (( $(date +%s) >= deadline )); then
            echo "Timeout waiting for server"
            return 1
        fi
        sleep 1
    done
    return 0
}
|
||||
|
||||
|
||||
# Entry point: validate the environment, launch the prefiller/decoder vLLM
# servers plus the FastAPI proxy, wait until all three are up, run the
# benchmark through the proxy, and clean everything up.
main() {
    # Preconditions: valid HF token, >=2 GPUs, required python packages.
    check_hf_token
    check_num_gpus
    ensure_python_library_installed lmcache
    ensure_python_library_installed nixl
    ensure_python_library_installed pandas
    ensure_python_library_installed datasets
    ensure_python_library_installed vllm

    # Kill the whole process group on interrupt/termination.
    trap cleanup INT
    trap cleanup USR1
    trap cleanup TERM

    echo "Launching prefiller, decoder and proxy..."
    echo "Please check prefiller.log, decoder.log and proxy.log for logs."

    # Each server runs in the background; process substitution tees its
    # combined stdout/stderr into a per-component log file.
    bash disagg_vllm_launcher.sh prefiller \
        > >(tee prefiller.log) 2>&1 &
    prefiller_pid=$!
    PIDS+=($prefiller_pid)

    bash disagg_vllm_launcher.sh decoder \
        > >(tee decoder.log) 2>&1 &
    decoder_pid=$!
    PIDS+=($decoder_pid)

    # Proxy listens on 9000 and fans out to prefiller (8100) / decoder (8200).
    python3 disagg_proxy_server.py \
        --host localhost \
        --port 9000 \
        --prefiller-host localhost \
        --prefiller-port 8100 \
        --decoder-host localhost \
        --decoder-port 8200 \
        > >(tee proxy.log) 2>&1 &
    proxy_pid=$!
    PIDS+=($proxy_pid)

    # Block until every component is reachable before benchmarking.
    wait_for_server 8100
    wait_for_server 8200
    wait_for_server 9000

    echo "All servers are up. Starting benchmark..."

    # begin benchmark
    cd ../../../../benchmarks/
    vllm bench serve --port 9000 --seed $(date +%s) \
        --model meta-llama/Llama-3.1-8B-Instruct \
        --dataset-name random --random-input-len 7500 --random-output-len 200 \
        --num-prompts 200 --burstiness 100 --request-rate 3.6 | tee benchmark.log

    echo "Benchmarking done. Cleaning up..."

    cleanup
}
|
||||
|
||||
main
|
||||
@@ -0,0 +1,225 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.

    On startup one persistent httpx client is created per backend so that
    connections are pooled across requests; both are closed on shutdown.
    """

    def make_client(base_url: str):
        # timeout=None and unbounded limits: generations are long-running and
        # the proxy should not cap backend concurrency itself.
        return httpx.AsyncClient(
            timeout=None,
            base_url=base_url,
            limits=httpx.Limits(
                max_connections=None,
                max_keepalive_connections=None,
            ),
        )

    # Startup: Initialize clients
    app.state.prefill_client = make_client(
        f"http://{global_args.prefiller_host}:{global_args.prefiller_port}/v1"
    )
    app.state.decode_client = make_client(
        f"http://{global_args.decoder_host}:{global_args.decoder_port}/v1"
    )

    yield

    # Shutdown: Close clients
    await app.state.prefill_client.aclose()
    await app.state.decode_client.aclose()
|
||||
|
||||
|
||||
# Update FastAPI app initialization to use lifespan
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
class StatsCalculator:
    """Accumulates prefill TTFT samples and periodically prints summaries.

    Logging is driven from ``add`` (no background thread): a summary is
    printed at most once every five seconds.
    """

    # Minimum number of seconds between two printed summaries.
    _LOG_INTERVAL_S = 5

    def __init__(self):
        self._stats = []
        self._last_log_time = time.time()

    def add(self, value):
        """Record one sample; emit the periodic summary if it is due."""
        self._stats.append(value)
        if time.time() - self._last_log_time > self._LOG_INTERVAL_S:
            self._log_stats()
            self._last_log_time = time.time()

    def _log_stats(self):
        """Print request count plus mean, median, and p99 of the samples."""
        samples = np.array(self._stats)
        summary = (
            f"\nNum requests: {len(self._stats)}"
            + "\nPrefill node TTFT stats:"
            + f"\n - Average (ms): {np.mean(samples)}"
            + f"\n - Median (ms): {np.median(samples)}"
            + f"\n - 99th Percentile (ms): {np.percentile(samples, 99)}\n"
        )
        print(
            "===============================",
            summary,
            "===============================",
        )
|
||||
|
||||
|
||||
stats_calculator = StatsCalculator()
|
||||
counter = 0
|
||||
|
||||
|
||||
def parse_args():
    """Parse proxy CLI flags: listen host/port plus the backend endpoints."""
    parser = argparse.ArgumentParser()
    # (flag, type, default) for every endpoint-related option.
    options = [
        ("--port", int, 8000),
        ("--host", str, "localhost"),
        ("--prefiller-host", str, "localhost"),
        ("--prefiller-port", int, 8100),
        ("--decoder-host", str, "localhost"),
        ("--decoder-port", int, 8200),
    ]
    for flag, flag_type, default in options:
        parser.add_argument(flag, type=flag_type, default=default)
    return parser.parse_args()
|
||||
|
||||
|
||||
# Initialize variables to hold the persistent clients
|
||||
app.state.prefill_client = None
|
||||
app.state.decode_client = None
|
||||
|
||||
|
||||
async def send_request_to_service(
    client: httpx.AsyncClient, endpoint: str, req_data: dict
):
    """
    Send a request to a service using a persistent client.

    The request is capped to a single output token: the prefill side only
    needs to populate the KV cache, not produce a full completion.
    """
    # Copy before mutating so the caller's payload is untouched.
    payload = req_data.copy()
    payload["max_tokens"] = 1
    if "max_completion_tokens" in payload:
        payload["max_completion_tokens"] = 1

    auth_headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    response = await client.post(endpoint, json=payload, headers=auth_headers)
    response.raise_for_status()

    # read/consume the response body to release the connection
    # otherwise, it would http.ReadError
    await response.aread()

    return response
|
||||
|
||||
|
||||
async def stream_service_response(
    client: httpx.AsyncClient, endpoint: str, req_data: dict
):
    """
    Asynchronously stream the response from a service using a persistent client.

    Yields raw byte chunks so SSE framing from the backend passes through
    untouched.
    """
    auth_headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    stream_ctx = client.stream("POST", endpoint, json=req_data, headers=auth_headers)
    async with stream_ctx as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk
|
||||
|
||||
|
||||
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Proxy a /v1/completions call: prefill first, then stream the decode."""
    global counter, stats_calculator
    counter += 1

    start_time = time.time()
    try:
        req_data = await request.json()

        # Run the prefill pass first; its output is discarded — we only need
        # the KV cache it leaves behind for the decoder.
        await send_request_to_service(
            app.state.prefill_client, "/completions", req_data
        )

        stats_calculator.add(time.time() - start_time)

        # Relay the decoder's streamed output to the caller.
        async def generate_stream():
            async for chunk in stream_service_response(
                app.state.decode_client, "/completions", req_data
            ):
                yield chunk

        return StreamingResponse(generate_stream(), media_type="text/event-stream")

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server - completions endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    """Proxy a /v1/chat/completions call: prefill first, then stream decode."""
    global counter, stats_calculator
    counter += 1

    start_time = time.time()
    try:
        req_data = await request.json()

        # Run the prefill pass first; its output is discarded — we only need
        # the KV cache it leaves behind for the decoder.
        await send_request_to_service(
            app.state.prefill_client, "/chat/completions", req_data
        )

        stats_calculator.add(time.time() - start_time)

        # Relay the decoder's streamed output to the caller.
        async def generate_stream():
            async for chunk in stream_service_response(
                app.state.decode_client, "/chat/completions", req_data
            ):
                yield chunk

        return StreamingResponse(generate_stream(), media_type="text/event-stream")

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print(
            "Error occurred in disagg prefill proxy server - chat completions endpoint"
        )
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
global global_args
|
||||
global_args = parse_args()
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
# Launch one half of the disaggregated-prefill pair: a prefiller vLLM server
# on port 8100 (GPU 0) or a decoder on port 8200 (GPU 1), both wired to
# LMCache over NIXL via their respective YAML configs.

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

if [[ $# -lt 1 ]]; then
    echo "Usage: $0 <prefiller | decoder> [model]"
    exit 1
fi

# Optional second argument overrides the served model.
if [[ $# -eq 1 ]]; then
    echo "Using default model: meta-llama/Llama-3.1-8B-Instruct"
    MODEL="meta-llama/Llama-3.1-8B-Instruct"
else
    echo "Using model: $2"
    MODEL=$2
fi

# The prefillers and decoders in LMCache use the same hash seed for all chunk keys.
# This seed must be aligned so that decoders can identify and retrieve KV cache
# entries stored by prefillers.
#
# WARNING: Using a fixed hash seed is insecure and makes the application vulnerable to
# denial-of-service attacks. In a production environment, this should be set to a
# secure random value. This is set to a fixed value for demonstration purposes only.
export PYTHONHASHSEED=${VLLM_PYTHON_HASH_SEED:-123}

if [[ $1 == "prefiller" ]]; then
    # Prefiller listens on port 8100
    prefill_config_file=$SCRIPT_DIR/configs/lmcache-prefiller-config.yaml

    # Environment: NIXL/UCX transports, LMCache config, and GPU pinning are
    # all scoped to this single vllm invocation.
    UCX_TLS=cuda_ipc,cuda_copy,tcp \
        LMCACHE_CONFIG_FILE=$prefill_config_file \
        LMCACHE_USE_EXPERIMENTAL=True \
        VLLM_ENABLE_V1_MULTIPROCESSING=1 \
        VLLM_WORKER_MULTIPROC_METHOD=spawn \
        CUDA_VISIBLE_DEVICES=0 \
        vllm serve $MODEL \
        --port 8100 \
        --enforce-eager \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_producer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "producer1"}}'


elif [[ $1 == "decoder" ]]; then
    # Decoder listens on port 8200
    decode_config_file=$SCRIPT_DIR/configs/lmcache-decoder-config.yaml

    UCX_TLS=cuda_ipc,cuda_copy,tcp \
        LMCACHE_CONFIG_FILE=$decode_config_file \
        LMCACHE_USE_EXPERIMENTAL=True \
        VLLM_ENABLE_V1_MULTIPROCESSING=1 \
        VLLM_WORKER_MULTIPROC_METHOD=spawn \
        CUDA_VISIBLE_DEVICES=1 \
        vllm serve $MODEL \
        --port 8200 \
        --enforce-eager \
        --kv-transfer-config \
        '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_consumer","kv_connector_extra_config": {"discard_partial_chunks": false, "lmcache_rpc_port": "consumer1"}}'


else
    echo "Invalid role: $1"
    echo "Should be either prefiller, decoder"
    exit 1
fi
|
||||
133
examples/others/lmcache/kv_cache_sharing_lmcache_v1.py
Normal file
133
examples/others/lmcache/kv_cache_sharing_lmcache_v1.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates the example usage of remote KV cache sharing
|
||||
with LMCache.
|
||||
We will launch 2 vllm instances, and launch an additional LMCache server.
|
||||
KV cache is transferred in the following manner:
|
||||
(1) vLLM instance 1 -> LMCache server (KV cache store).
|
||||
(2) LMCache server -> vLLM instance 2 (KV cache reuse/retrieve).
|
||||
|
||||
Note that lmcache needs to be installed to run this example.
|
||||
Learn more about LMCache in https://github.com/LMCache/LMCache.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from multiprocessing import Event, Process
|
||||
|
||||
from lmcache.integration.vllm.utils import ENGINE_NAME
|
||||
from lmcache.v1.cache_engine import LMCacheEngineBuilder
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
# LMCache-related environment variables
|
||||
# The port to start LMCache server
|
||||
port = 8100
|
||||
# Use experimental features in LMCache
|
||||
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
|
||||
# LMCache is set to use 256 tokens per chunk
|
||||
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
|
||||
# Disable local CPU backend in LMCache
|
||||
os.environ["LMCACHE_LOCAL_CPU"] = "False"
|
||||
# Set local CPU memory buffer limit to 5.0 GB
|
||||
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
|
||||
# Set the remote URL for LMCache server
|
||||
os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
|
||||
# Set the serializer/deserializer between vllm and LMCache server
|
||||
# `naive` indicates using raw bytes of the tensor without any compression
|
||||
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you?" * 1000,
|
||||
]
|
||||
|
||||
|
||||
def run_store(store_done, prompts):
    """Store worker: generate on GPU 0 so KV caches land in the LMCache server.

    Sets ``store_done`` once the KV cache has been produced and pushed to the
    remote LMCache server.
    """
    # Pin this worker to GPU 0.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    transfer_config = KVTransferConfig(
        kv_connector="LMCacheConnectorV1", kv_role="kv_both"
    )
    # gpu_memory_utilization=0.8 targets a 40GB A40; reduce the value if your
    # GPU has less memory.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        kv_transfer_config=transfer_config,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    for completion in llm.generate(prompts, sampling_params):
        print(f"Generated text: {completion.outputs[0].text!r}")
    print("KV cache store is finished.")
    store_done.set()

    # Clean up lmcache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||
|
||||
|
||||
def run_retrieve(store_done, prompts, timeout=1):
    """Retrieve worker: wait for the store side, then reuse its KV cache on GPU 1.

    Args:
        store_done: event set by the store worker when its generation is done.
        prompts: prompts whose KV cache was stored by the other instance.
        timeout: extra seconds to wait after the event fires, giving the
            LMCache server time to settle.
    """
    # Pin this worker to GPU 1.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    transfer_config = KVTransferConfig(
        kv_connector="LMCacheConnectorV1", kv_role="kv_both"
    )
    # gpu_memory_utilization=0.8 targets a 40GB A40; reduce the value if your
    # GPU has less memory.
    llm = LLM(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        kv_transfer_config=transfer_config,
        max_model_len=8000,
        gpu_memory_utilization=0.8,
        enforce_eager=True,
    )

    print("Waiting for KV cache store to finish...")
    store_done.wait()
    time.sleep(timeout)

    for completion in llm.generate(prompts, sampling_params):
        print(f"Generated text: {completion.outputs[0].text!r}")

    # Clean up lmcache backend
    LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||
|
||||
|
||||
def run_lmcache_server(port):
    """Launch the standalone LMCache v1 server as a subprocess.

    Args:
        port: TCP port the server should listen on (host is localhost).

    Returns:
        The ``subprocess.Popen`` handle; the caller is responsible for
        terminating and reaping it.
    """
    import sys  # local import: keeps this example's top-level imports unchanged

    # Use sys.executable rather than a bare "python" so the server runs under
    # the same interpreter/virtualenv as this script ("python" may be missing
    # or resolve to a different environment without lmcache installed).
    server_proc = subprocess.Popen(
        [sys.executable, "-m", "lmcache.v1.server", "localhost", str(port)]
    )
    return server_proc
|
||||
|
||||
|
||||
def main():
    """Orchestrate the KV-cache-sharing demo.

    Starts the LMCache server plus one store and one retrieve process, waits
    for both to finish, then tears the server down.
    """
    store_done = Event()
    store_process = Process(target=run_store, args=(store_done, prompts))
    retrieve_process = Process(target=run_retrieve, args=(store_done, prompts))
    lmcache_server_process = run_lmcache_server(port)

    try:
        # Start KV cache store process
        store_process.start()

        # Start KV cache retrieve process
        retrieve_process.start()

        store_process.join()
        # BUG FIX: the retriever was previously terminate()d as soon as the
        # store process joined, which could kill it mid-retrieval (it only
        # starts generating after store_done fires).  Join it instead so the
        # shared-cache generation completes.
        retrieve_process.join()
    finally:
        lmcache_server_process.terminate()
        lmcache_server_process.wait()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
162
examples/others/logging_configuration.md
Normal file
162
examples/others/logging_configuration.md
Normal file
@@ -0,0 +1,162 @@
|
||||
# Logging Configuration
|
||||
|
||||
vLLM leverages Python's `logging.config.dictConfig` functionality to enable
|
||||
robust and flexible configuration of the various loggers used by vLLM.
|
||||
|
||||
vLLM offers two environment variables that can be used to accommodate a range
|
||||
of logging configurations that range from simple-and-inflexible to
|
||||
more-complex-and-more-flexible.
|
||||
|
||||
- No vLLM logging (simple and inflexible)
|
||||
- Set `VLLM_CONFIGURE_LOGGING=0` (leaving `VLLM_LOGGING_CONFIG_PATH` unset)
|
||||
- vLLM's default logging configuration (simple and inflexible)
|
||||
- Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1`
|
||||
- Fine-grained custom logging configuration (more complex, more flexible)
|
||||
- Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` and
|
||||
set `VLLM_LOGGING_CONFIG_PATH=<path-to-logging-config.json>`
|
||||
|
||||
## Logging Configuration Environment Variables
|
||||
|
||||
### `VLLM_CONFIGURE_LOGGING`
|
||||
|
||||
`VLLM_CONFIGURE_LOGGING` controls whether or not vLLM takes any action to
|
||||
configure the loggers used by vLLM. This functionality is enabled by default,
|
||||
but can be disabled by setting `VLLM_CONFIGURE_LOGGING=0` when running vLLM.
|
||||
|
||||
If `VLLM_CONFIGURE_LOGGING` is enabled and no value is given for
|
||||
`VLLM_LOGGING_CONFIG_PATH`, vLLM will use built-in default configuration to
|
||||
configure the root vLLM logger. By default, no other vLLM loggers are
|
||||
configured and, as such, all vLLM loggers defer to the root vLLM logger to make
|
||||
all logging decisions.
|
||||
|
||||
If `VLLM_CONFIGURE_LOGGING` is disabled and a value is given for
|
||||
`VLLM_LOGGING_CONFIG_PATH`, an error will occur while starting vLLM.
|
||||
|
||||
### `VLLM_LOGGING_CONFIG_PATH`
|
||||
|
||||
`VLLM_LOGGING_CONFIG_PATH` allows users to specify a path to a JSON file of
|
||||
alternative, custom logging configuration that will be used instead of vLLM's
|
||||
built-in default logging configuration. The logging configuration should be
|
||||
provided in JSON format following the schema specified by Python's [logging
|
||||
configuration dictionary
|
||||
schema](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details).
|
||||
|
||||
If `VLLM_LOGGING_CONFIG_PATH` is specified, but `VLLM_CONFIGURE_LOGGING` is
|
||||
disabled, an error will occur while starting vLLM.
|
||||
|
||||
## Examples
|
||||
|
||||
### Example 1: Customize vLLM root logger
|
||||
|
||||
For this example, we will customize the vLLM root logger to use
|
||||
[`python-json-logger`](https://github.com/nhairs/python-json-logger)
|
||||
(which is part of the container image) to log to
|
||||
STDOUT of the console in JSON format with a log level of `INFO`.
|
||||
|
||||
To begin, first, create an appropriate JSON logging configuration file:
|
||||
|
||||
??? note "/path/to/logging_config.json"
|
||||
|
||||
```json
|
||||
{
|
||||
"formatters": {
|
||||
"json": {
|
||||
"class": "pythonjsonlogger.jsonlogger.JsonFormatter"
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class" : "logging.StreamHandler",
|
||||
"formatter": "json",
|
||||
"level": "INFO",
|
||||
"stream": "ext://sys.stdout"
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
"vllm": {
|
||||
"handlers": ["console"],
|
||||
"level": "INFO",
|
||||
"propagate": false
|
||||
}
|
||||
},
|
||||
"version": 1
|
||||
}
|
||||
```
|
||||
|
||||
Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set
|
||||
to the path of the custom logging configuration JSON file:
|
||||
|
||||
```bash
|
||||
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
|
||||
vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
|
||||
```
|
||||
|
||||
### Example 2: Silence a particular vLLM logger
|
||||
|
||||
To silence a particular vLLM logger, it is necessary to provide custom logging
|
||||
configuration for the target logger that configures the logger so that it won't
|
||||
propagate its log messages to the root vLLM logger.
|
||||
|
||||
When custom configuration is provided for any logger, it is also necessary to
|
||||
provide configuration for the root vLLM logger since any custom logger
|
||||
configuration overrides the built-in default logging configuration used by vLLM.
|
||||
|
||||
First, create an appropriate JSON logging configuration file that includes
|
||||
configuration for the root vLLM logger and for the logger you wish to silence:
|
||||
|
||||
??? note "/path/to/logging_config.json"
|
||||
|
||||
```json
|
||||
{
|
||||
"formatters": {
|
||||
"vllm": {
|
||||
"class": "vllm.logging_utils.NewLineFormatter",
|
||||
"datefmt": "%m-%d %H:%M:%S",
|
||||
"format": "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
|
||||
}
|
||||
},
|
||||
"handlers": {
|
||||
"vllm": {
|
||||
"class" : "logging.StreamHandler",
|
||||
"formatter": "vllm",
|
||||
"level": "INFO",
|
||||
"stream": "ext://sys.stdout"
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
"vllm": {
|
||||
"handlers": ["vllm"],
|
||||
"level": "DEBUG",
|
||||
"propagate": false
|
||||
},
|
||||
"vllm.example_noisy_logger": {
|
||||
"propagate": false
|
||||
}
|
||||
},
|
||||
"version": 1
|
||||
}
|
||||
```
|
||||
|
||||
Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set
|
||||
to the path of the custom logging configuration JSON file:
|
||||
|
||||
```bash
|
||||
VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \
|
||||
vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
|
||||
```
|
||||
|
||||
### Example 3: Disable vLLM default logging configuration
|
||||
|
||||
To disable vLLM's default logging configuration and silence all vLLM loggers,
|
||||
simply set `VLLM_CONFIGURE_LOGGING=0` when running vLLM. This will prevent vLLM
|
||||
from configuring the root vLLM logger, which in turn silences all other vLLM
|
||||
loggers.
|
||||
|
||||
```bash
|
||||
VLLM_CONFIGURE_LOGGING=0 \
|
||||
vllm serve mistralai/Mistral-7B-v0.1 --max-model-len 2048
|
||||
```
|
||||
|
||||
## Additional resources
|
||||
|
||||
- [`logging.config` Dictionary Schema Details](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details)
|
||||
392
examples/others/tensorize_vllm_model.py
Normal file
392
examples/others/tensorize_vllm_model.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerArgs,
|
||||
TensorizerConfig,
|
||||
tensorize_lora_adapter,
|
||||
tensorize_vllm_model,
|
||||
tensorizer_kwargs_arg,
|
||||
)
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
# Module-level logger; note this grabs the ROOT logger (no name argument),
# so records go wherever the root logger's handlers send them.
logger = logging.getLogger()
|
||||
|
||||
|
||||
"""
|
||||
tensorize_vllm_model.py is a script that can be used to serialize and
|
||||
deserialize vLLM models. These models can be loaded using tensorizer
|
||||
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
|
||||
or locally. Tensor encryption and decryption is also supported, although
|
||||
libsodium must be installed to use it. Install vllm with tensorizer support
|
||||
using `pip install vllm[tensorizer]`. To learn more about tensorizer, visit
|
||||
https://github.com/coreweave/tensorizer
|
||||
|
||||
To serialize a model, install vLLM from source, then run something
|
||||
like this from the root level of this repository:
|
||||
|
||||
python examples/others/tensorize_vllm_model.py \
|
||||
--model facebook/opt-125m \
|
||||
serialize \
|
||||
--serialized-directory s3://my-bucket \
|
||||
--suffix v1
|
||||
|
||||
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
||||
and saves it to your S3 bucket. A local directory can also be used. This
|
||||
assumes your S3 credentials are specified as environment variables
|
||||
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and
|
||||
`S3_ENDPOINT_URL`. To provide S3 credentials directly, you can provide
|
||||
`--s3-access-key-id` and `--s3-secret-access-key`, as well as `--s3-endpoint`
|
||||
as CLI args to this script.
|
||||
|
||||
You can also encrypt the model weights with a randomly-generated key by
|
||||
providing a `--keyfile` argument.
|
||||
|
||||
To deserialize a model, you can run something like this from the root
|
||||
level of this repository:
|
||||
|
||||
python examples/others/tensorize_vllm_model.py \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
deserialize \
|
||||
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors
|
||||
|
||||
Which downloads the model tensors from your S3 bucket and deserializes them.
|
||||
|
||||
You can also provide a `--keyfile` argument to decrypt the model weights if
|
||||
they were serialized with encryption.
|
||||
|
||||
To support distributed tensor-parallel models, each model shard will be
|
||||
serialized to a separate file. The tensorizer_uri is then specified as a string
|
||||
template with a format specifier such as '%03d' that will be rendered with the
|
||||
shard's rank. Sharded models serialized with this script will be named as
|
||||
model-rank-%03d.tensors
|
||||
|
||||
For more information on the available arguments for serializing, run
|
||||
`python -m examples.others.tensorize_vllm_model serialize --help`.
|
||||
|
||||
Or for deserializing:
|
||||
|
||||
`python examples/others/tensorize_vllm_model.py deserialize --help`.
|
||||
|
||||
Once a model is serialized, tensorizer can be invoked with the `LLM` class
|
||||
directly to load models:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(
|
||||
"s3://my-bucket/vllm/facebook/opt-125m/v1",
|
||||
load_format="tensorizer",
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
A serialized model can be used during model loading for the vLLM OpenAI
|
||||
inference server:
|
||||
|
||||
```
|
||||
vllm serve s3://my-bucket/vllm/facebook/opt-125m/v1 \
|
||||
--load-format tensorizer
|
||||
```
|
||||
|
||||
In order to see all of the available arguments usable to configure
|
||||
loading with tensorizer that are given to `TensorizerConfig`, run:
|
||||
|
||||
`python examples/others/tensorize_vllm_model.py deserialize --help`
|
||||
|
||||
under the `tensorizer options` section. These can also be used for
|
||||
deserialization in this example script, although `--tensorizer-uri` and
|
||||
`--path-to-tensors` are functionally the same in this case.
|
||||
|
||||
Tensorizer can also be used to save and load LoRA adapters. A LoRA adapter
|
||||
can be serialized directly with the path to the LoRA adapter on HF Hub and
|
||||
a TensorizerConfig object. In this script, passing a HF id to a LoRA adapter
|
||||
will serialize the LoRA adapter artifacts to `--serialized-directory`.
|
||||
|
||||
You can then use the LoRA adapter with `vllm serve`, for instance, by ensuring
|
||||
the LoRA artifacts are in your model artifacts directory and specifying
|
||||
`--enable-lora`. For instance:
|
||||
|
||||
```
|
||||
vllm serve s3://my-bucket/vllm/facebook/opt-125m/v1 \
|
||||
--load-format tensorizer \
|
||||
--enable-lora
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def get_parser():
    """Build the CLI parser for this example.

    The parser is assembled in three layers:
      1. vLLM's engine arguments (via ``EngineArgs.add_cli_args``), so the
         usual ``--model`` / engine flags are available,
      2. a shared ``--lora-path`` option,
      3. a required ``serialize`` / ``deserialize`` subcommand, each with
         its own options.

    Returns:
        The fully configured ``FlexibleArgumentParser``.
    """
    parser = FlexibleArgumentParser(
        description="An example script that can be used to serialize and "
        "deserialize vLLM models. These models "
        "can be loaded using tensorizer directly to the GPU "
        "extremely quickly. Tensor encryption and decryption is "
        "also supported, although libsodium must be installed to "
        "use it."
    )
    # Fold vLLM's own engine flags into this parser.
    parser = EngineArgs.add_cli_args(parser)

    parser.add_argument(
        "--lora-path",
        type=str,
        required=False,
        help="Path to a LoRA adapter to "
        "serialize along with model tensors. This can then be deserialized "
        "along with the model by instantiating a TensorizerConfig object, "
        "creating a dict from it with TensorizerConfig.to_serializable(), "
        "and passing it to LoRARequest's initializer with the kwarg "
        "tensorizer_config_dict.",
    )

    # required=True: argparse guarantees args.command is one of the two
    # subcommand names below.
    subparsers = parser.add_subparsers(dest="command", required=True)

    serialize_parser = subparsers.add_parser(
        "serialize", help="Serialize a model to `--serialized-directory`"
    )

    serialize_parser.add_argument(
        "--suffix",
        type=str,
        required=False,
        help=(
            "The suffix to append to the serialized model directory, which is "
            "used to construct the location of the serialized model tensors, "
            "e.g. if `--serialized-directory` is `s3://my-bucket/` and "
            "`--suffix` is `v1`, the serialized model tensors will be "
            "saved to "
            "`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
            "If none is provided, a random UUID will be used."
        ),
    )
    serialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=True,
        help="The directory to serialize the model to. "
        "This can be a local directory or S3 URI. The path to where the "
        "tensors are saved is a combination of the supplied `dir` and model "
        "reference ID. For instance, if `dir` is the serialized directory, "
        "and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
        "be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
        "where `suffix` is given by `--suffix` or a random UUID if not "
        "provided.",
    )

    serialize_parser.add_argument(
        "--serialization-kwargs",
        type=tensorizer_kwargs_arg,
        required=False,
        help=(
            "A JSON string containing additional keyword arguments to "
            "pass to Tensorizer's TensorSerializer during "
            "serialization."
        ),
    )

    serialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=(
            "Encrypt the model weights with a randomly-generated binary key,"
            " and save the key at this path"
        ),
    )

    deserialize_parser = subparsers.add_parser(
        "deserialize",
        help=(
            "Deserialize a model from `--path-to-tensors`"
            " to verify it can be loaded and used."
        ),
    )

    deserialize_parser.add_argument(
        "--path-to-tensors",
        type=str,
        required=False,
        help="The local path or S3 URI to the model tensors to deserialize. ",
    )

    deserialize_parser.add_argument(
        "--serialized-directory",
        type=str,
        required=False,
        help="Directory with model artifacts for loading. Assumes a "
        "model.tensors file exists therein. Can supersede "
        "--path-to-tensors.",
    )

    deserialize_parser.add_argument(
        "--keyfile",
        type=str,
        required=False,
        help=(
            "Path to a binary key to use to decrypt the model weights,"
            " if the model was serialized with encryption"
        ),
    )

    deserialize_parser.add_argument(
        "--deserialization-kwargs",
        type=tensorizer_kwargs_arg,
        required=False,
        help=(
            "A JSON string containing additional keyword arguments to "
            "pass to Tensorizer's `TensorDeserializer` during "
            "deserialization."
        ),
    )

    # Tensorizer-specific loading options (the `tensorizer options` section
    # mentioned in the module docstring) only apply to `deserialize`.
    TensorizerArgs.add_cli_args(deserialize_parser)

    return parser
|
||||
|
||||
|
||||
def merge_extra_config_with_tensorizer_config(extra_cfg: dict, cfg: TensorizerConfig):
    """Copy recognized keys from *extra_cfg* onto *cfg* in place.

    Only keys that already exist as attributes on *cfg* are applied; any
    other key is silently skipped. Each applied key is logged at INFO.
    """
    for key, value in extra_cfg.items():
        if not hasattr(cfg, key):
            continue
        setattr(cfg, key, value)
        logger.info(
            "Updating TensorizerConfig with %s from "
            "--model-loader-extra-config provided",
            key,
        )
|
||||
|
||||
|
||||
def deserialize(args, tensorizer_config):
    """Load a tensorized model (optionally with a LoRA adapter) and return it.

    Args:
        args: Parsed CLI namespace; ``model``, ``tensor_parallel_size`` and
            ``lora_path`` are read here.
        tensorizer_config: The ``TensorizerConfig`` describing where the
            serialized tensors live; mutated in place when a LoRA path is
            given.

    Returns:
        The constructed ``LLM`` instance.
    """
    if args.lora_path:
        # LoRA artifacts were serialized alongside the model tensors, so
        # point the adapter directory at the tensorizer directory.
        tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
        llm = LLM(
            model=args.model,
            load_format="tensorizer",
            tensor_parallel_size=args.tensor_parallel_size,
            model_loader_extra_config=tensorizer_config,
            enable_lora=True,
        )
        sampling_params = SamplingParams(
            temperature=0, max_tokens=256, stop=["[/assistant]"]
        )

        # Truncating this as the extra text isn't necessary
        prompts = ["[user] Write a SQL query to answer the question based on ..."]

        # Test LoRA load
        print(
            llm.generate(
                prompts,
                sampling_params,
                lora_request=LoRARequest(
                    "sql-lora",
                    1,
                    args.lora_path,
                    tensorizer_config_dict=tensorizer_config.to_serializable(),
                ),
            )
        )
    else:
        llm = LLM(
            model=args.model,
            load_format="tensorizer",
            tensor_parallel_size=args.tensor_parallel_size,
            model_loader_extra_config=tensorizer_config,
        )
    return llm
|
||||
|
||||
|
||||
def main():
    """CLI entry point: serialize or deserialize a vLLM model with tensorizer."""
    parser = get_parser()
    args = parser.parse_args()

    # S3 credentials: CLI arguments win, falling back to the environment.
    s3_access_key_id = getattr(args, "s3_access_key_id", None) or os.environ.get(
        "S3_ACCESS_KEY_ID", None
    )
    s3_secret_access_key = getattr(
        args, "s3_secret_access_key", None
    ) or os.environ.get("S3_SECRET_ACCESS_KEY", None)
    s3_endpoint = getattr(args, "s3_endpoint", None) or os.environ.get(
        "S3_ENDPOINT_URL", None
    )

    credentials = {
        "s3_access_key_id": s3_access_key_id,
        "s3_secret_access_key": s3_secret_access_key,
        "s3_endpoint": s3_endpoint,
    }

    model_ref = args.model

    # The subparser is required and both subcommands define --keyfile, so
    # args.keyfile always exists; the old "serialize or deserialize" guard
    # (with a dead `keyfile = None` branch) was redundant.
    keyfile = args.keyfile

    extra_config = {}
    if args.model_loader_extra_config:
        extra_config = json.loads(args.model_loader_extra_config)

    # Resolve the tensor location from the CLI, falling back to
    # --model-loader-extra-config; exactly one of the two forms is allowed.
    tensorizer_dir = args.serialized_directory or extra_config.get("tensorizer_dir")
    tensorizer_uri = getattr(args, "path_to_tensors", None) or extra_config.get(
        "tensorizer_uri"
    )

    if tensorizer_dir and tensorizer_uri:
        parser.error(
            "--serialized-directory and --path-to-tensors cannot both be provided"
        )

    if not tensorizer_dir and not tensorizer_uri:
        parser.error(
            "Either --serialized-directory or --path-to-tensors must be provided"
        )

    if args.command == "serialize":
        engine_args = EngineArgs.from_cli_args(args)

        input_dir = tensorizer_dir.rstrip("/")
        suffix = args.suffix if args.suffix else uuid.uuid4().hex
        base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
        # Tensor-parallel models are written one file per rank; '%03d' is
        # rendered with the shard's rank by tensorizer.
        if engine_args.tensor_parallel_size > 1:
            model_path = f"{base_path}/model-rank-%03d.tensors"
        else:
            model_path = f"{base_path}/model.tensors"

        tensorizer_config = TensorizerConfig(
            tensorizer_uri=model_path,
            encryption_keyfile=keyfile,
            serialization_kwargs=args.serialization_kwargs or {},
            **credentials,
        )

        if args.lora_path:
            # Serialize the LoRA adapter into the same directory as the
            # model tensors.
            tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
            tensorize_lora_adapter(args.lora_path, tensorizer_config)

        merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config)
        tensorize_vllm_model(engine_args, tensorizer_config)

    elif args.command == "deserialize":
        tensorizer_config = TensorizerConfig(
            tensorizer_uri=args.path_to_tensors,
            tensorizer_dir=args.serialized_directory,
            encryption_keyfile=keyfile,
            deserialization_kwargs=args.deserialization_kwargs or {},
            **credentials,
        )

        merge_extra_config_with_tensorizer_config(extra_config, tensorizer_config)
        deserialize(args, tensorizer_config)
    else:
        # Unreachable while the subparser is required; kept as a guard.
        raise ValueError("Either serialize or deserialize must be specified.")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user