Sync from v0.13
This commit is contained in:
@@ -1,15 +1,25 @@
|
||||
import argparse
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.torch_utils import (
|
||||
STR_DTYPE_TO_TORCH_DTYPE,
|
||||
create_kv_caches_with_random,
|
||||
)
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NUM_BLOCKS = 128 * 1024
|
||||
PARTITION_SIZE = 512
|
||||
PARTITION_SIZE_ROCM = 256
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -26,27 +36,20 @@ def main(
|
||||
seed: int,
|
||||
do_profile: bool,
|
||||
device: str = "cuda",
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
kv_cache_dtype: str | None = None,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
query = torch.empty(
|
||||
num_seqs, num_query_heads, head_size, dtype=dtype, device=device
|
||||
)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
|
||||
|
||||
seq_lens = [seq_len for _ in range(num_seqs)]
|
||||
max_seq_len = max(seq_lens)
|
||||
@@ -54,30 +57,38 @@ def main(
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
block_tables_lst: list[list[int]] = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
|
||||
block_tables_lst.append(block_table)
|
||||
|
||||
block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
|
||||
|
||||
# Create the KV cache.
|
||||
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device)
|
||||
key_caches, value_caches = create_kv_caches_with_random(
|
||||
NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Prepare for the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v2":
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
if current_platform.is_rocm():
|
||||
global PARTITION_SIZE
|
||||
if not args.custom_paged_attn and not current_platform.is_navi():
|
||||
PARTITION_SIZE = 1024
|
||||
else:
|
||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
@@ -97,7 +108,7 @@ def main(
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Using default kv_scale
|
||||
kv_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
@@ -114,34 +125,58 @@ def main(
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
elif version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
)
|
||||
if not args.custom_paged_attn:
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
None,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
@@ -157,39 +192,43 @@ def main(
|
||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version",
|
||||
type=str,
|
||||
choices=["v1", "v2"],
|
||||
default="v2")
|
||||
if __name__ == "__main__":
|
||||
logger.warning(
|
||||
"This script benchmarks the paged attention kernel. "
|
||||
"By default this is no longer used in vLLM inference."
|
||||
)
|
||||
|
||||
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--seq_len", type=int, default=4096)
|
||||
parser.add_argument("--seq-len", type=int, default=4096)
|
||||
parser.add_argument("--num-query-heads", type=int, default=64)
|
||||
parser.add_argument("--num-kv-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 128, 256],
|
||||
default=128)
|
||||
parser.add_argument(
|
||||
"--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||
default=128,
|
||||
)
|
||||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||
parser.add_argument("--use-alibi", action="store_true")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="half")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8"],
|
||||
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
||||
'common inference criteria.')
|
||||
help="Data type for kv cache storage. If 'auto', will use model "
|
||||
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
|
||||
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-paged-attn", action="store_true", help="Use custom paged attention"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user