Sync from v0.13
benchmarks/kernels/bench_block_fp8_gemm.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

# Disable DeepGEMM for this benchmark to use CUTLASS
os.environ["VLLM_USE_DEEP_GEMM"] = "0"

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton as vllm_triton

assert current_platform.is_cuda(), (
    "Only support benchmarking w8a8 block fp8 kernel on CUDA device."
)

# DeepSeek-V3 weight shapes
DEEPSEEK_V3_SHAPES = [
    (512 + 64, 7168),
    (2112, 7168),
    ((128 + 64) * 128, 7168),
    (128 * (128 + 128), 512),
    (7168, 16384),
    (7168, 18432),
    (18432 * 2, 7168),
    (24576, 1536),
    (12288, 7168),
    (4096, 7168),
    (7168, 2048),
]


def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
    """Build runner function for w8a8 block fp8 matmul."""
    factor_for_scale = 1e-2

    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    # Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
    A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max

    # Create quantized weight tensor
    B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
    B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    # Create weight scales
    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    Bs = (
        torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
        * factor_for_scale
    )

    # Create W8A8BlockFp8LinearOp instance
    weight_group_shape = GroupShape(block_n, block_k)
    act_quant_group_shape = GroupShape(1, block_k)  # Per-token, per-group quantization

    linear_op = W8A8BlockFp8LinearOp(
        weight_group_shape=weight_group_shape,
        act_quant_group_shape=act_quant_group_shape,
        cutlass_block_fp8_supported=use_cutlass,
        use_aiter_and_is_supported=False,
    )

    def run():
        return linear_op.apply(
            input=A_ref,
            weight=B,
            weight_scale=Bs,
            input_scale=None,
            bias=None,
        )

    return run


# Determine available providers
available_providers = ["torch-bf16", "w8a8-block-fp8-triton"]
plot_title = "BF16 vs W8A8 Block FP8 GEMMs"

if CUTLASS_BLOCK_FP8_SUPPORTED:
    available_providers.append("w8a8-block-fp8-cutlass")


@vllm_triton.testing.perf_report(
    vllm_triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=available_providers,
        line_names=available_providers,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs W8A8 Block FP8 GEMMs",
        args={},
    )
)
def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
    M = batch_size
    device = "cuda"

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        a = torch.randn((M, K), device=device, dtype=torch.bfloat16)
        b = torch.randn((N, K), device=device, dtype=torch.bfloat16)
        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    elif provider == "w8a8-block-fp8-triton":
        run_w8a8_triton = build_w8a8_block_fp8_runner(
            M, N, K, block_size, device, use_cutlass=False
        )
        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
            lambda: run_w8a8_triton(), quantiles=quantiles
        )
    elif provider == "w8a8-block-fp8-cutlass":
        run_w8a8_cutlass = build_w8a8_block_fp8_runner(
            M, N, K, block_size, device, use_cutlass=True
        )
        ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
            lambda: run_w8a8_cutlass(), quantiles=quantiles
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


if __name__ == "__main__":
    block_size = (128, 128)

    for N, K in DEEPSEEK_V3_SHAPES:
        print(f"\nBenchmarking DeepSeek-V3, N={N} K={K}")

        print(f"TFLOP/s comparison (block_size={block_size}):")
        benchmark_tflops.run(
            print_data=True,
            # show_plots=False,
            # save_path=f"bench_w8a8_block_fp8_tflops_n{N}_k{K}",
            N=N,
            K=K,
            block_size=block_size,
        )

    print("\nBenchmark finished!")
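For reference, the to_tflops conversion above counts 2 * M * N * K floating-point operations per GEMM (one multiply and one add per output element per K step). A small worked example with a hypothetical timing:

# Worked example of the to_tflops conversion (hypothetical timing).
M, N, K = 512, 7168, 16384          # one DeepSeek-V3 shape at batch size 512
t_ms = 1.0                          # assume the kernel took 1 ms
tflops = (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
print(f"{tflops:.1f} TFLOP/s")      # ~120.3 TFLOP/s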
benchmarks/kernels/bench_fp8_gemm.py (new file, 159 lines)
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "fp8-tensor-w-token-a": dict(
        w="tensor", a="token", no_a_quant=False, enabled=False
    ),
    "fp8-tensor-w-tensor-a": dict(
        w="tensor", a="tensor", no_a_quant=False, enabled=True
    ),
    "fp8-channel-w-token-a": dict(
        w="channel", a="token", no_a_quant=False, enabled=True
    ),
    "fp8-channel-w-tensor-a": dict(
        w="channel", a="tensor", no_a_quant=False, enabled=False
    ),
    "fp8-tensor-w-token-a-noquant": dict(
        w="tensor", a="token", no_a_quant=True, enabled=False
    ),
    "fp8-tensor-w-tensor-a-noquant": dict(
        w="tensor", a="tensor", no_a_quant=True, enabled=True
    ),
    "fp8-channel-w-token-a-noquant": dict(
        w="channel", a="token", no_a_quant=True, enabled=True
    ),
    "fp8-channel-w-tensor-a-noquant": dict(
        w="channel", a="tensor", no_a_quant=True, enabled=False
    ),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
    if w_type == "tensor":
        scale_b = torch.ones(1, device=device, dtype=torch.float32)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    else:
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
    return b_fp8.t(), scale_b_fp8


def build_fp8_runner(cfg, a, b, dtype, device):
    b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)

    scale_a_const = (
        torch.ones(1, device=device, dtype=torch.float32)
        if cfg["a"] == "tensor"
        else None
    )

    if cfg["no_a_quant"]:
        if cfg["a"] == "tensor":
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
        else:
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)

        def run():
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

        return run

    if cfg["a"] == "tensor":

        def run():
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    else:

        def run():
            a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
            return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs FP8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_fp8_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_fp8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
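prepare_shapes turns each (shape, tp_dim) entry of WEIGHT_SHAPES into a [K, N, model] triple, dividing the tensor-parallel-sharded dimension by the TP size. A minimal sketch of one iteration, using a hypothetical WEIGHT_SHAPES entry:

# Minimal sketch of prepare_shapes on one hypothetical WEIGHT_SHAPES entry.
import copy

WEIGHT_SHAPES = {"example/model": [([4096, 6144], 1)]}  # hypothetical entry
tp_size = 2
KN, tp_dim = copy.deepcopy(WEIGHT_SHAPES["example/model"][0])
KN[tp_dim] //= tp_size          # shard dim 1: 6144 -> 3072
KN.append("example/model")
print(KN)                       # [4096, 3072, 'example/model']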
benchmarks/kernels/bench_int8_gemm.py (new file, 169 lines)
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "int8-tensor-w-token-a": dict(
        w="tensor", a="token", no_a_quant=False, enabled=False
    ),
    "int8-tensor-w-tensor-a": dict(
        w="tensor", a="tensor", no_a_quant=False, enabled=True
    ),
    "int8-channel-w-token-a": dict(
        w="channel", a="token", no_a_quant=False, enabled=True
    ),
    "int8-channel-w-tensor-a": dict(
        w="channel", a="tensor", no_a_quant=False, enabled=False
    ),
    "int8-tensor-w-token-a-noquant": dict(
        w="tensor", a="token", no_a_quant=True, enabled=False
    ),
    "int8-tensor-w-tensor-a-noquant": dict(
        w="tensor", a="tensor", no_a_quant=True, enabled=True
    ),
    "int8-channel-w-token-a-noquant": dict(
        w="channel", a="token", no_a_quant=True, enabled=True
    ),
    "int8-channel-w-tensor-a-noquant": dict(
        w="channel", a="tensor", no_a_quant=True, enabled=False
    ),
}


def _quant_weight(b, w_type, device):
    if w_type == "tensor":
        scale_b = torch.ones(1, device=device, dtype=torch.float32)
        b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
        assert scale_b_int8.numel() == 1
    else:  # channel
        b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
        assert scale_b_int8.numel() == b.shape[0]
    return b_int8.t(), scale_b_int8


def build_int8_runner(cfg, a, b, dtype, device):
    # Quantize the weight once, before running the kernel.
    b_int8, scale_b_int8 = _quant_weight(b, cfg["w"], device)

    scale_a_const = None
    if cfg["a"] == "tensor":
        scale_a_const = torch.ones(1, device=device, dtype=torch.float32)

    # no_a_quant: quantize the activation ahead of time, outside the timed run
    if cfg["no_a_quant"]:
        if cfg["a"] == "tensor":
            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const)
        else:  # token
            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)

        def run_quant():
            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)

        return run_quant

    # dynamic quant: quantize the activation inside the timed run
    if cfg["a"] == "tensor":

        def run_quant():
            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const)
            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)

    else:  # token

        def run_quant():
            a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
            return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)

    return run_quant


_enabled = [k for k, v in PROVIDER_CFGS.items() if v.get("enabled")]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=[k for k in _enabled],
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs INT8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16
    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_int8_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    KN_model_names = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_int8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
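The "-noquant" provider variants isolate the GEMM cost: the activation is quantized once when the runner is built, while the default variants quantize inside the timed closure and so also pay the dynamic-quantization overhead. A schematic sketch of the distinction, with quantize and gemm as hypothetical stand-ins for the vLLM ops used above:

# Schematic contrast of the two timed paths (quantize/gemm are placeholders).
def make_noquant_runner(a):
    a_int8, scale_a = quantize(a)        # hoisted out: not timed
    def run():
        return gemm(a_int8, scale_a)     # timer sees the GEMM only
    return run

def make_dynamic_runner(a):
    def run():
        a_int8, scale_a = quantize(a)    # timed: quant + GEMM
        return gemm(a_int8, scale_a)
    return run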
benchmarks/kernels/bench_mxfp4_qutlass.py (new file, 191 lines)
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import copy
import itertools

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES

from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "mxfp4": dict(no_a_quant=False, enabled=True),
    "mxfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return (
        deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
        * group_size**-0.5
    )


def _quant_weight_mxfp4(
    b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
):
    weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
        b, forward_hadamard_matrix, method="abs_max"
    )
    weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
    return weight_hf_e2m1, weight_hf_scale_block


def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
        b, forward_hadamard_matrix, device
    )
    alpha = torch.tensor([1.0], device="cuda")

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")

        def run():
            return matmul_mxf4_bf16_tn(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
        return matmul_mxf4_bf16_tn(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            4096,
            8192,
            16384,
            24576,
            32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs MXFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_mxfp4_runner(
            cfg, a, b, forward_hadamard_matrix, dtype, device
        )
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), rep=200, quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [32, 64, 128]:
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path=f"bench_mxfp4_res_n{N}_k{K}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
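get_hadamard_matrix scales the +/-1 Hadamard matrix by group_size**-0.5, which makes the rotation orthonormal (H @ H.T == I), so it reshapes the value distribution before quantization without changing the product the GEMM computes. A quick check of that property, using scipy's hadamard as a stand-in for the compressed_tensors helper:

# Orthonormality check for the group_size**-0.5 scaling (scipy stand-in).
import torch
from scipy.linalg import hadamard

n = 32
H = torch.tensor(hadamard(n), dtype=torch.float64) * n**-0.5
assert torch.allclose(H @ H.T, torch.eye(n, dtype=torch.float64))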
benchmarks/kernels/bench_nvfp4_gemm.py (new file, 198 lines)
@@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
import os

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton

if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")


FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "nvfp4": dict(no_a_quant=False, enabled=True),
    "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
    "fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True),
    "fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True),
}

_needs_fbgemm = any(
    v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False)
)
if _needs_fbgemm:
    try:
        from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
            triton_scale_nvfp4_quant,
        )
    except ImportError:
        print(
            "WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
            "These providers will be skipped. Please install fbgemm_gpu with: "
            "'pip install fbgemm-gpu-genai' to run them."
        )
        # Disable FBGEMM providers so the benchmark can run.
        for cfg in PROVIDER_CFGS.values():
            if cfg.get("fbgemm"):
                cfg["enabled"] = False

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg):
    # Compute global scale for weight
    b_amax = torch.abs(b).max().to(torch.float32)
    b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    if "fbgemm" in cfg and cfg["fbgemm"]:
        b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)
    else:
        b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
    return b_fp4, scale_b_fp4, b_global_scale


def build_nvfp4_runner(cfg, a, b, dtype, device):
    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg)

    # Compute global scale for activation
    # NOTE: This is generally provided ahead-of-time by the model checkpoint.
    a_amax = torch.abs(a).max().to(torch.float32)
    a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax

    # Alpha for the GEMM operation
    alpha = 1.0 / (a_global_scale * b_global_scale)
    if "fbgemm" in cfg and cfg["fbgemm"]:
        if cfg["no_a_quant"]:
            a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)

            def run():
                return torch.ops.fbgemm.f4f4bf16(
                    a_fp4,
                    b_fp4,
                    scale_a_fp4,
                    scale_b_fp4,
                    global_scale=alpha,
                    use_mx=False,
                )

            return run
        else:

            def run():
                a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
                return torch.ops.fbgemm.f4f4bf16(
                    a_fp4,
                    b_fp4,
                    scale_a_fp4,
                    scale_b_fp4,
                    global_scale=alpha,
                    use_mx=False,
                )

            return run

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)

        def run():
            return ops.cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
            )

        return run

    # Quantize activation on-the-fly
    def run():
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
        return ops.cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
        save_dir = f"bench_nvfp4_res_n{N}_k{K}"
        os.makedirs(save_dir, exist_ok=True)

        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=save_dir,
            N=N,
            K=K,
        )

    print("Benchmark finished!")
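The global scale above maps a tensor's absolute maximum onto the largest value the two-level NVFP4 format can represent: FLOAT8_E4M3_MAX is 448.0 and FLOAT4_E2M1_MAX is 6.0. A worked example with a hypothetical amax:

# Worked example of the global-scale formula above (hypothetical amax).
FLOAT8_E4M3_MAX = 448.0   # torch.finfo(torch.float8_e4m3fn).max
FLOAT4_E2M1_MAX = 6.0     # max of float4_e2m1f
amax = 10.0               # assume the tensor's abs().max() is 10
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax   # = 268.8
# If both operands happened to share this amax, the GEMM alpha would be
# 1.0 / (global_scale * global_scale), undoing both scalings at the output.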
benchmarks/kernels/bench_nvfp4_qutlass.py (new file, 207 lines)
@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import copy
import itertools

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops  # use existing nvfp4 gemm in vllm
from vllm._custom_ops import fusedQuantizeNv
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "nvfp4": dict(no_a_quant=False, enabled=True),
    "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return (
        deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
        * group_size**-0.5
    )


def _quant_weight_nvfp4(
    b: torch.Tensor,
    forward_hadamard_matrix: torch.Tensor,
    global_scale: torch.Tensor,
    device: str,
    M: int,
    N: int,
    K: int,
):
    weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv(
        b, forward_hadamard_matrix, global_scale
    )
    weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view(
        -1, K // 16
    )
    return weight_hf_e2m1, weight_hf_scale_block


def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
    alpha = torch.tensor([1.0], device="cuda")
    global_scale = torch.tensor([1.0], device="cuda")
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
        b, forward_hadamard_matrix, global_scale, device, M, N, K
    )

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
            a, forward_hadamard_matrix, global_scale
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
            -1, K // 16
        )

        def run():
            return ops.cutlass_scaled_fp4_mm(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
                torch.bfloat16,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
            a, forward_hadamard_matrix, global_scale
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
            -1, K // 16
        )
        return ops.cutlass_scaled_fp4_mm(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
            torch.bfloat16,
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            4096,
            8192,
            16384,
            24576,
            32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_nvfp4_runner(
            cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K
        )
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), rep=200, quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [16, 32, 64, 128]:
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path=f"bench_nvfp4_res_n{N}_k{K}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
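NVFP4 keeps one 8-bit scale per 16-element group along K, which is why the blocked scale tensors above are viewed as (-1, K // 16) before the GEMM. A shape sketch with hypothetical sizes:

# Shape sketch for the .view(-1, K // 16) calls above (hypothetical sizes).
import torch

rows, K = 128, 4096
blocked = torch.empty(rows * (K // 16), dtype=torch.uint8)  # stand-in for e8m0 scales
scales = blocked.view(-1, K // 16)   # one scale per 16-element group along K
assert scales.shape == (rows, 256)   # 4096 // 16 = 256 groups per row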
benchmarks/kernels/bench_per_token_quant_fp8.py (new file, 270 lines)
@@ -0,0 +1,270 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Callable
from unittest.mock import patch

import pandas as pd
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE


def with_triton_mode(fn):
    """Temporarily force the Triton fallback path"""

    def wrapped(*args, **kwargs):
        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
            return fn(*args, **kwargs)

    return wrapped


# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
    def inner(*args):
        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
        return fn(*args)

    return inner


def bench_compile(fn: Callable):
    # recompile for different shapes
    fwd = torch.compile(fn, fullgraph=True, dynamic=False)

    # First dim is explicitly dynamic to simulate vLLM usage
    return with_dyn_arg(fwd, 0, 0)


torch._dynamo.config.recompile_limit = 8888


def calculate_diff(
    batch_size: int,
    hidden_size: int,
    group_shape: GroupShape,
    dtype: torch.dtype,
):
    """Calculate the difference between Inductor and CUDA implementations."""
    device = torch.device("cuda")
    x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)

    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

    try:
        torch.testing.assert_close(
            cuda_out.to(torch.float32),
            torch_out.to(torch.float32),
            rtol=1e-3,
            atol=1e-5,
        )
        torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)
        torch.testing.assert_close(
            cuda_out.to(torch.float32),
            torch_eager_out.to(torch.float32),
            rtol=1e-3,
            atol=1e-5,
        )
        torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5)
        print("✅ All implementations match")
    except AssertionError as e:
        print("❌ Implementations differ")
        print(e)


configs = []


def benchmark_quantization(
    batch_size,
    hidden_size,
    provider,
    group_shape: GroupShape,
    col_major: bool,
    dtype: torch.dtype,
):
    device = torch.device("cuda")

    x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

    if provider == "torch":
        fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
    elif provider == "cuda":
        fn = lambda: quant_fp8.forward_cuda(x.clone())
    elif provider == "triton":
        if not group_shape.is_per_group():
            # Triton only supported for per-group
            return 0, 0, 0

        fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


# TODO(luka) extract to utils
def compute_geomean_speedups(
    df: pd.DataFrame,
    baseline_col: str,
    speedup_cols: list[str],
    groupby_cols: list[str] | None = None,
) -> pd.DataFrame:
    """
    Compute geometric mean speedups over a baseline column.

    Args:
        df: Input dataframe
        baseline_col: Column to use as baseline
        speedup_cols: Columns to compute speedups for
        groupby_cols: Columns to group by. If None, compute over entire df.

    Returns:
        pd.DataFrame with geometric mean speedups
    """
    from scipy.stats import gmean

    def geo_speedup(group: pd.DataFrame) -> pd.Series:
        ratios = {
            col: (group[baseline_col] / group[col]).values for col in speedup_cols
        }
        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})

    if groupby_cols is None:
        result = geo_speedup(df).to_frame().T
    else:
        result = (
            df.groupby(groupby_cols)
            .apply(geo_speedup, include_groups=False)
            .reset_index()
        )

    return result


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
    )
    parser.add_argument("-c", "--check", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
    )
    parser.add_argument(
        "--hidden-sizes",
        type=int,
        nargs="+",
        default=[896, 1024, 2048, 4096, 7168],
        help="Hidden sizes to benchmark",
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 16, 128, 512, 1024],
        help="Batch sizes to benchmark",
    )
    parser.add_argument(
        "--group-sizes",
        type=int,
        nargs="+",
        default=None,
        help="Group sizes for GroupShape(1,N) to benchmark. "
        "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
    )
    parser.add_argument(
        "--no-column-major",
        action="store_true",
        help="Disable column-major scales testing",
    )

    args = parser.parse_args()
    assert args

    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

    hidden_sizes = args.hidden_sizes
    batch_sizes = args.batch_sizes

    if args.group_sizes is not None:
        group_shapes = []
        for size in args.group_sizes:
            if size == 0:
                group_shapes.append(GroupShape.PER_TENSOR)
            elif size == -1:
                group_shapes.append(GroupShape.PER_TOKEN)
            else:
                group_shapes.append(GroupShape(1, size))
    else:
        group_shapes = [
            GroupShape.PER_TENSOR,
            GroupShape.PER_TOKEN,
            GroupShape(1, 64),
            GroupShape(1, 128),
        ]

    column_major_scales = [False] if args.no_column_major else [True, False]

    config_gen = itertools.product(
        group_shapes,
        column_major_scales,
        batch_sizes,
        hidden_sizes,
    )

    # filter out column-major scales for non-group, reverse order
    configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))

    print(f"Running {len(configs)} configurations:")
    print(f" Hidden sizes: {hidden_sizes}")
    print(f" Batch sizes: {batch_sizes}")
    print(f" Group shapes: {[str(g) for g in group_shapes]}")
    print(f" Column major scales: {column_major_scales}")
    print()

    if args.check:
        for group_shape in group_shapes:
            group_size = group_shape[1]
            print(f"{group_size=}")
            calculate_diff(
                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
            )

    benchmark = triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
            x_vals=configs,
            line_arg="provider",
            line_vals=["torch", "cuda", "triton"],
            line_names=["Torch (Compiled)", "CUDA", "Triton"],
            styles=[("blue", "-"), ("green", "-"), ("black", "-")],
            ylabel="us",
            plot_name="QuantFP8 performance",
            args={},
        )
    )(benchmark_quantization)

    df = benchmark.run(print_data=True, dtype=dtype, return_df=True)

    # Print geomean speedups
    geo_table_grouped = compute_geomean_speedups(
        df,
        baseline_col="Torch (Compiled)",
        speedup_cols=["CUDA", "Triton"],
        groupby_cols=["col_major", "group_shape"],
    )

    print("Speedup over Torch (Compiled)")
    print(geo_table_grouped.to_string(index=False))
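compute_geomean_speedups divides the baseline column by each speedup column row-wise and takes the geometric mean of the ratios. A toy usage, with hypothetical timings in us and the function defined as above:

# Toy usage of compute_geomean_speedups (hypothetical timings, in us).
import pandas as pd

df = pd.DataFrame({
    "Torch (Compiled)": [10.0, 20.0],
    "CUDA": [5.0, 8.0],
    "Triton": [8.0, 10.0],
})
print(compute_geomean_speedups(df, "Torch (Compiled)", ["CUDA", "Triton"]))
# CUDA: sqrt((10/5) * (20/8)) ~= 2.24, Triton: sqrt((10/8) * (20/10)) ~= 1.58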
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py (new file, 244 lines)
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _per_token_group_quant_fp8_colmajor,
    silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used

from .utils import ArgPool, Bench, CudaGraphBenchParams

GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn


def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    print(
        f"Note: The timings reported above are for {cuda_graph_nops} "
        "consecutive invocations of the benchmarking functions. "
        f"Please divide by {cuda_graph_nops} for single invocation "
        "timings."
    )
    compare = TBenchmark.Compare(timers)
    compare.print()


class ImplType(Enum):
    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    REFERENCE = 2

    def get_impl(self):
        if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        elif self == ImplType.REFERENCE:
            return reference
        raise ValueError(f"Unrecognized ImplType {self}")


@dataclass
class BenchmarkTensors:
    input: torch.Tensor
    output: torch.Tensor

    # Reference act output tensor
    ref_act_out: torch.Tensor
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0

        input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

        # silu_mul_per_token_group_quant_fp8_colmajor output.
        output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
            FLOAT8_T
        )

        # reference output.
        ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
        ref_quant_out = torch.empty(
            (T, N // 2), dtype=torch.bfloat16, device="cuda"
        ).to(FLOAT8_T)

        return BenchmarkTensors(
            input=input,
            output=output,
            ref_act_out=ref_act_out,
            ref_quant_out=ref_quant_out,
        )

    @property
    def T(self):
        return self.input.size(0)

    @property
    def N(self):
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return {
                "input": self.input,
                "output": self.output,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        elif impl_type == ImplType.REFERENCE:
            return {
                "input": self.input,
                "act_out": self.ref_act_out,
                "quant_out": self.ref_quant_out,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        raise ValueError(f"Unrecognized impl_type {impl_type}")


def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """
    Reference triton quant kernel from,
    vllm.model_executor.layers.quantization.utils.fp8_utils
    """
    assert quant_out.size() == x.size()
    # Allocate the scale tensor in column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)

    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max

    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s


def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    torch.ops._C.silu_and_mul(act_out, input)
    return reference_quant(act_out, quant_out, use_ue8m0)


def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
    T = bench_tensors[0].T
    N = bench_tensors[0].N

    arg_pool_size = len(bench_tensors)
    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]

    # warmup
    for kwargs in kwargs_list:
        impl_type.get_impl()(**kwargs)
    torch.cuda.synchronize()

    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)

    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
    timer = None
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl_type.get_impl(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer


def test_correctness(T: int, N: int):
    print(f"Testing num_tokens={T}, N={N} ...")

    bench_tensor = BenchmarkTensors.make(T, N)

    def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))

    # reference output
    ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)

    # test output
    out_q, out_s = output_from_impl(
        ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
    )

    torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
    torch.testing.assert_close(ref_out_s, out_s)


def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
    timers = []
    for N, T in product(Ns, Ts):
        test_correctness(T, N)

        bench_tensors: list[BenchmarkTensors] = [
            BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
        ]

        silu_mul_quant_timer = bench_impl(
            bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
        )
        timers.append(silu_mul_quant_timer)
        reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
        timers.append(reference_timer)

        print_timers(
            [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
        )

    print_timers(timers, cuda_graph_nops=arg_pool_size)

    return timers


if __name__ == "__main__":
    T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
    N = [2048, 4096, 8192]

    print(f"T = {T}, N = {N}")
    run(T, N, arg_pool_size=8)
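The scale allocation in reference_quant builds a (num_groups, T) tensor and permutes it, so the returned (T, num_groups) view has stride 1 along the token dimension, i.e. it is column-major. A shape-and-stride sketch with hypothetical sizes:

# Shape/stride sketch for the column-major scale allocation above
# (hypothetical sizes: T = 256 tokens, N = 1024 columns, GROUP_SIZE = 128).
import torch

T, N, GROUP_SIZE = 256, 1024, 128
x_s = torch.empty(N // GROUP_SIZE, T, dtype=torch.float32).permute(-1, -2)
assert x_s.shape == (T, N // GROUP_SIZE)   # logical view is (256, 8)
assert x_s.stride() == (1, T)              # contiguous along the token dim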
benchmarks/kernels/benchmark_activation.py (new file, 105 lines)
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# benchmark custom activation op performance
import itertools

import torch

import vllm.model_executor.layers.activation  # noqa: F401
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE

batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
intermediate_size = [3072, 9728, 12288]
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))


def benchmark_activation(
    batch_size: int,
    seq_len: int,
    intermediate_size: int,
    provider: str,
    func_name: str,
    dtype: torch.dtype,
):
    device = "cuda"
    num_tokens = batch_size * seq_len
    dim = intermediate_size
    current_platform.seed_everything(42)
    torch.set_default_device(device)

    if func_name == "gelu_and_mul":
        layer = CustomOp.op_registry[func_name](approximate="none")
    elif func_name == "gelu_and_mul_tanh":
        layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh")
    elif func_name == "fatrelu_and_mul":
        threshold = 0.5
        layer = CustomOp.op_registry[func_name](threshold)
    else:
        layer = CustomOp.op_registry[func_name]()

    x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
    compiled_layer = torch.compile(layer.forward_native)

    if provider == "custom":
        fn = lambda: layer(x)
    elif provider == "compiled":
        fn = lambda: compiled_layer(x)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=[0.5, 0.2, 0.8]
    )
    return ms, max_ms, min_ms


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the custom activation op.")
    parser.add_argument(
        "--func-name",
        type=str,
        choices=[
            "mul_and_silu",
            "silu_and_mul",
            "gelu_and_mul",
            "gelu_and_mul_tanh",
            "fatrelu_and_mul",
            "swigluoai_and_mul",
            "gelu_new",
            "gelu_fast",
            "quick_gelu",
        ],
        default="silu_and_mul",
    )
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
    )
    args = parser.parse_args()
    assert args

    func_name = args.func_name
    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

    perf_report = triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len", "intermediate_size"],
            x_vals=configs,
            line_arg="provider",
            line_vals=["custom", "compiled"],
            line_names=["Custom OP", "Compiled"],
            styles=[("blue", "-"), ("green", "-")],
            ylabel="ms",
            plot_name=f"{func_name}-op-performance",
            args={},
        )
    )

    perf_report(
        lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation(
            batch_size, seq_len, intermediate_size, provider, func_name, dtype
        )
    ).run(print_data=True)
@@ -1,302 +0,0 @@
import argparse
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
    dequantize_weight, generic_dequantize_gemm, get_int_dtype,
    optimized_dequantize_gemm)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def torch_mult(
    input: torch.Tensor,  # [..., in_features]
    weights: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    output = F.linear(input, weights)
    return output


def dequant_out_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)
    else:
        b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)


def dequant_weight_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
        -1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)


def dequant_no_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    return F.linear(input, weights, bias)


# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    count = 0
    for index in range(16):
        for i in range(8):
            for book in range(nbooks):
                codebooks[book, index, 0, i] = count * (10**book)
            count += 1

    print("codes shape", codes.shape)

    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])


def main():

    parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument("--nbooks",
                        type=int,
                        default=1,
                        help="Number of codebooks (default: 1)")
    parser.add_argument("--bits",
                        type=int,
                        default=16,
                        help="Number of bits per code element (default: 16)")
    parser.add_argument(
        "--test",
        type=bool,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)")

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        sys.stdout = f

        print('m | k | n | n parts', end='')
        for method in methods:
            print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
        print('')

        # These are reasonable prefill sizes.
        ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
                         (4096, (11008, 11008)), (11008, (4096, )))

        # reasonable ranges for m.
        for m in [
                1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
                128, 256, 512, 1024, 1536, 2048, 3072, 4096
        ]:
            print(f'{m}', file=sys.__stdout__)
            for ksp in ksandpartions:
                run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
                         methods)

        sys.stdout = sys.__stdout__


def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
             methods):

    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms

            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        # Report the best trial (the original printed the last trial's time).
        print(f' | {best_time_us:.0f}', end='')

    print('')


def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
               nbooks: int, bits: int, method) -> float:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    if method is torch_mult:
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
244
benchmarks/kernels/benchmark_bitblas.py
Normal file
@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from packaging import version

from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    MINIMUM_BITBLAS_VERSION,
)

try:
    import bitblas

    if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
        raise ImportError(
            "bitblas version is too old. Please "
            f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
        )
except ImportError as e:
    bitblas_import_exception = e
    raise ValueError(
        "Trying to use the bitblas backend, but could not import it "
        f"with the following error: {bitblas_import_exception}. "
        "Please install bitblas through the following command: "
        f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
    ) from bitblas_import_exception

from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target

from vllm.utils.argparse_utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(
    description="Benchmark BitBLAS int4 on a specific target."
)

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking.",
)
parser.add_argument(
    "--group_size", type=int, default=None, help="Group size for grouped quantization."
)
parser.add_argument(
    "--A_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "float64", "int32", "int8"],
    help="Data type of activation A.",
)
parser.add_argument(
    "--W_dtype",
    type=str,
    default="int4",
    choices=[
        "float16",
        "float32",
        "float64",
        "int32",
        "int8",
        "int4",
        "int2",
        "int1",
        "nf4",
        "fp4_e2m1",
    ],
    help="Data type of weight W.",
)
parser.add_argument(
    "--accum_dtype",
    type=str,
    default="float16",
    choices=["float16", "int32"],
    help="Data type for accumulation.",
)
parser.add_argument(
    "--out_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "int32", "int8"],
    help="Data type for output.",
)
parser.add_argument(
    "--layout",
    type=str,
    default="nt",
    choices=["nt", "nn"],
    help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
parser.add_argument(
    "--with_bias", action="store_true", help="Include bias in the benchmark."
)
parser.add_argument(
    "--with_scaling",
    action="store_true",
    help="Include scaling factor in the quantization.",
)
parser.add_argument(
    "--with_zeros", action="store_true", help="Include zeros in the quantization."
)
parser.add_argument(
    "--zeros_mode",
    type=str,
    default=None,
    choices=["original", "rescale", "quantized"],
    help="Specify the mode for calculating zeros.",
)

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
group_size = args.group_size
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode

# Define a list of shared arguments that repeat in every config
shared_args = [
    A_dtype,
    W_dtype,
    out_dtype,
    accum_dtype,
    layout,
    with_bias,
    group_size,
    with_scaling,
    with_zeros,
    zeros_mode,
]

# Define just the (M, K, N) shapes in a more compact list
shapes = [
    # square test
    (1, 16384, 16384),
    # BLOOM-176B
    (1, 43008, 14336),
    (1, 14336, 14336),
    (1, 57344, 14336),
    (1, 14336, 57344),
    # OPT-65B
    (1, 9216, 9216),
    (1, 36864, 9216),
    (1, 9216, 36864),
    (1, 22016, 8192),
    # LLAMA-70B/65B
    (1, 8192, 22016),
    (1, 8192, 8192),
    (1, 28672, 8192),
    (1, 8192, 28672),
    # square test
    (16384, 16384, 16384),
    # BLOOM-176B
    (8192, 43008, 14336),
    (8192, 14336, 14336),
    (8192, 57344, 14336),
    (8192, 14336, 57344),
    # OPT-65B
    (8192, 9216, 9216),
    (8192, 36864, 9216),
    (8192, 9216, 36864),
    (8192, 22016, 8192),
    # LLAMA-70B/65B
    (8192, 8192, 22016),
    (8192, 8192, 8192),
    (8192, 28672, 8192),
    (8192, 8192, 28672),
]

# Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]

benchmark_sets = []
benchmark_sets.extend(test_shapes)

benchmark_results = {}
for config_class, operator, input_args in benchmark_sets:
    config = config_class(*input_args)
    matmul = operator(config, target=target, enable_tuning=True)
    kernel_latency = matmul.profile_latency()

    print("Time cost is: {:.3f} ms".format(kernel_latency))

    profile_config = {
        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
            "BitBLAS_top20_latency": kernel_latency,
        }
    }

    benchmark_results.update(profile_config)

# Define headers for the table
headers = [
    "PrimFunc",
    "Input Arguments",
    "BitBLAS Top20 Latency",
]

# Calculate column widths for pretty printing
col_widths = [0, 0, 0]
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
    col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
    col_widths[2] = max(
        col_widths[2],
        len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
        len(headers[2]) + 2,
    )
    # break only if you want to measure widths from a single example;
    # otherwise, let it loop over all items.

# Print header
for i, header in enumerate(headers):
    headers[i] = header.ljust(col_widths[i])
print("".join(headers))
print("-" * sum(col_widths))

# Print rows
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    row = [
        func_name,
        input_args_str,
        f"{values['BitBLAS_top20_latency']:.3f} ms",
    ]
    row_str = "".join(
        [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
    )
    print(row_str)
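# Example invocation (a sketch; the CUDA target is auto-detected when
# --target is omitted, and all flags below are defined above):
#   python benchmark_bitblas.py --A_dtype float16 --W_dtype int4 --layout nt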
504
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
Normal file
@@ -0,0 +1,504 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
activations. The triton_moe kernel takes in fp8 weights (tensor scaled to fp8)
and 16-bit activations.
"""

import nvtx
import torch
import torch.utils.benchmark as benchmark

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
    fp8_w8a8_moe_quant_config,
    nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser

WEIGHT_SHAPES_MOE = {
    "nvidia/DeepSeek-R1-FP4": [
        [256, 8, 2048, 7168],
    ],
}

DEFAULT_MODELS = [
    "nvidia/DeepSeek-R1-FP4",
]

DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
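# Note on the two constants above: NVFP4 uses a two-level scale, a per-16-
# element FP8 (e4m3) block scale plus a per-expert FP32 global scale. bench_run
# below sets global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax(weight),
# which maps the largest weight onto the largest representable
# block_scale * fp4 product, e.g. amax = 2.0 gives 448 * 6 / 2.0 = 1344.0.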
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )


def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={}, per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half
    device = "cuda"
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    _, a_fp8_scale = ops.scaled_fp8_quant(a)

    w1_fp8q = torch.empty(
        (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
    )
    w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
    w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
    w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])

    w1_fp8q_notransp = w1_fp8q.clone()
    w2_fp8q_notransp = w2_fp8q.clone()
    w1_fp8q = w1_fp8q.transpose(1, 2)
    w2_fp8q = w2_fp8q.transpose(1, 2)

    score = torch.randn((m, num_experts), device=device, dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    quant_blocksize = 16
    w1_blockscale = torch.empty(
        (num_experts, 2 * n, k // quant_blocksize),
        device=device,
        dtype=torch.float8_e4m3fn,
    )
    w2_blockscale = torch.empty(
        (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
    )

    # n_b_scales = 2 * n if per_out_ch else 1
    # k_b_scales = k if per_out_ch else 1
    w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
    w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)

    w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
    w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
    a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
    a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)

    for expert in range(num_experts):
        w1_e = w1[expert]
        w2_e = w2[expert]
        w1_amax = torch.abs(w1_e).max().to(torch.float32)
        w2_amax = torch.abs(w2_e).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax

        w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
            w1_e, w1_gs[expert]
        )

        w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
            w2_e, w2_gs[expert]
        )

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_fp8_scale: torch.Tensor,
        num_repeats: int,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_fp8_scale,
        )

        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def run_cutlass_moe_fp4(
        a: torch.Tensor,
        w1_fp4: torch.Tensor,
        w2_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w2_blockscale: torch.Tensor,
        w1_gs: torch.Tensor,
        w2_gs: torch.Tensor,
        a1_gs: torch.Tensor,
        a2_gs: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
        num_repeats: int,
    ):
        quant_config = nvfp4_moe_quant_config(
            a1_gscale=a1_gs,
            a2_gscale=a2_gs,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
            g1_alphas=w1_gs,
            g2_alphas=w2_gs,
        )
        for _ in range(num_repeats):
            with nvtx.annotate("cutlass_moe_fp4", color="green"):
                cutlass_moe_fp4(
                    a=a,
                    w1_fp4=w1_fp4,
                    w2_fp4=w2_fp4,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    m=m,
                    n=n,
                    k=k,
                    e=e,  # use the parameter, not the enclosing num_experts
                    quant_config=quant_config,
                )

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a1_gscale: torch.Tensor,
        w1_fp4: torch.Tensor,
        w1_blockscale: torch.Tensor,
        w1_alphas: torch.Tensor,
        a2_gscale: torch.Tensor,
        w2_fp4: torch.Tensor,
        w2_blockscale: torch.Tensor,
        w2_alphas: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        m: int,
        n: int,
        k: int,
        e: int,
        device: torch.device,
    ):
        # Build the config from this function's own parameters (the original
        # read the enclosing a1_gs/a2_gs/w1_gs/w2_gs directly, which carry the
        # same values but leave these parameters unused).
        quant_config = nvfp4_moe_quant_config(
            a1_gscale=a1_gscale,
            a2_gscale=a2_gscale,
            w1_scale=w1_blockscale,
            w2_scale=w2_blockscale,
            g1_alphas=w1_alphas,
            g2_alphas=w2_alphas,
        )

        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return cutlass_moe_fp4(
                a=a,
                w1_fp4=w1_fp4,
                w2_fp4=w2_fp4,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                m=m,
                n=n,
                k=k,
                e=e,
                quant_config=quant_config,
            )

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_fp8_scale: torch.Tensor,
    ):
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            quant_config = fp8_w8a8_moe_quant_config(
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                a1_scale=a_fp8_scale,
            )
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a=a,
            a1_gscale=a1_gs,
            w1_fp4=w1_fp4,
            w1_blockscale=w1_blockscale,
            w1_alphas=w1_gs,
            a2_gscale=a2_gs,
            w2_fp4=w2_fp4,
            w2_blockscale=w2_blockscale,
            w2_alphas=w2_gs,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            m=m,
            n=n,
            k=k,
            e=num_experts,
            device=device,
        )
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_fp8q_notransp,
            w2_fp8q_notransp,
            topk_weights,
            topk_ids,
            w1_fp8scale,
            w2_fp8scale,
            a_fp8_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_fp8q_notransp": w1_fp8q_notransp,
        "w2_fp8q_notransp": w2_fp8q_notransp,
        "w1_fp8scale": w1_fp8scale,
        "w2_fp8scale": w2_fp8scale,
        "a_fp8_scale": a_fp8_scale,
        # Cutlass params
        "a": a,
        "a1_gscale": a1_gs,
        "w1_fp4": w1_fp4,
        "w1_blockscale": w1_blockscale,
        "w1_alphas": w1_gs,
        "a2_gscale": a2_gs,
        "w2_fp4": w2_fp4,
        "w2_blockscale": w2_blockscale,
        "w2_alphas": w2_gs,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "m": m,
        "n": n,
        "k": k,
        "e": num_experts,
        "device": device,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe_fp4": run_cutlass_moe_fp4,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_fp8q_notransp,
        w2_fp8q_notransp,
        topk_weights,
        topk_ids,
        w1_fp8scale,
        w2_fp8scale,
        a_fp8_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe_fp4(
        a,
        w1_fp4,
        w2_fp4,
        w1_blockscale,
        w2_blockscale,
        w1_gs,
        w2_gs,
        a1_gs,
        a2_gs,
        topk_weights,
        topk_ids,
        m,
        n,
        k,
        num_experts,
        device,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="cutlass_moe_fp4",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="cutlass_moe_fp4_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
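# Example sweep (a sketch; model names must be keys of WEIGHT_SHAPES_MOE, and
# the batch sizes below are a subset of DEFAULT_BATCH_SIZES):
#   python benchmark_cutlass_fp4_moe.py --models nvidia/DeepSeek-R1-FP4 \
#       --tp-sizes 1 --batch-sizes 4 64 1024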
406
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
Normal file
@@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe
kernel. Both kernels take in fp8 quantized weights and 16-bit activations,
but use different quantization strategies and backends.
"""

import nvtx
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser

# Weight shapes for different models: [num_experts, topk, hidden_size,
# intermediate_size]
WEIGHT_SHAPES_MOE = {
    "mixtral-8x7b": [
        [8, 2, 4096, 14336],
    ],
    "deepseek-v2": [
        [160, 6, 5120, 12288],
    ],
    "custom-small": [
        [8, 2, 2048, 7168],
    ],
    "glm45-fp8": [
        [128, 8, 4096, 1408],
    ],
    "Llama-4-Maverick-17B-128E-Instruct-FP8": [
        [128, 1, 5120, 8192],
    ],
}

DEFAULT_MODELS = [
    "mixtral-8x7b",
]

DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False, True]
PER_OUT_CH_OPTS = [False, True]

FP8_DTYPE = current_platform.fp8_dtype()


def bench_run(
    results: list,
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    (m, k, n) = mkn

    dtype = torch.half
    device = "cuda"

    # Create input activations
    a = torch.randn((m, k), device=device, dtype=dtype) / 10

    # Create weights
    w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10

    # Create FP8 quantized weights and scales for both kernels
    w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE)
    w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE)

    # Create scales based on quantization strategy
    if per_out_ch:
        # Per-channel quantization
        w1_scale = torch.empty(
            (num_experts, 2 * n, 1), device=device, dtype=torch.float32
        )
        w2_scale = torch.empty((num_experts, k, 1), device=device, dtype=torch.float32)
    else:
        # Per-tensor quantization
        w1_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
        w2_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)

    # Quantize weights
    for expert in range(num_experts):
        if per_out_ch:
            # Per-channel quantization - not yet implemented properly
            # For now, fall back to per-tensor quantization
            w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
            w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
            # Expand scalar scales to the expected per-channel shape
            w1_scale[expert] = w1_scale_temp.expand(2 * n, 1)
            w2_scale[expert] = w2_scale_temp.expand(k, 1)
        else:
            # Per-tensor quantization
            w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
            w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
            # Store scalar scales in [1, 1] tensors
            w1_scale[expert, 0, 0] = w1_scale_temp
            w2_scale[expert, 0, 0] = w2_scale_temp

    # Prepare weights for CUTLASS (no transpose needed)
    w1_fp8q_cutlass = w1_fp8q  # Keep original [E, 2N, K]
    w2_fp8q_cutlass = w2_fp8q  # Keep original [E, K, N]

    # Create router scores and get topk
    score = torch.randn((m, num_experts), device=device, dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)

    # WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization
    # Force per-tensor quantization for all cases to match working e2e setup
    a1_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
    a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)

    # Force per-tensor quantization for all cases
    per_act_token = False

    # Create stride tensors for CUTLASS
    ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
    ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device)
    c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device)
    c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device)

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a1_scale: torch.Tensor,
        a2_scale: torch.Tensor,
        num_repeats: int,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            per_act_token_quant=per_act_token,
            per_out_ch_quant=per_out_ch,
        )

        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def run_cutlass_moe_fp8(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        ab_strides1: torch.Tensor,
        ab_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a1_scale: torch.Tensor,
        a2_scale: torch.Tensor,
        num_repeats: int,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            per_act_token_quant=per_act_token,
            per_out_ch_quant=per_out_ch,
        )

        for _ in range(num_repeats):
            with nvtx.annotate("cutlass_moe_fp8", color="blue"):
                cutlass_moe_fp8(
                    a=a,
                    w1_q=w1,
                    w2_q=w2,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    ab_strides1=ab_strides1,
                    ab_strides2=ab_strides2,
                    c_strides1=c_strides1,
                    c_strides2=c_strides2,
                    quant_config=quant_config,
                    activation="silu",
                    global_num_experts=num_experts,
                )

    # Pre-create quantization config to avoid creating it inside CUDA graph
    quant_config = fp8_w8a8_moe_quant_config(
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_act_token,
        per_out_ch_quant=per_out_ch,
    )

    # Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly)
    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        # Capture 10 invocations like benchmark_moe.py
        for _ in range(10):
            cutlass_moe_fp8(
                a=a,
                w1_q=w1_fp8q_cutlass,
                w2_q=w2_fp8q_cutlass,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                ab_strides1=ab_strides1,
                ab_strides2=ab_strides2,
                c_strides1=c_strides1,
                c_strides2=c_strides2,
                quant_config=quant_config,
                activation="silu",
                global_num_experts=num_experts,
            )
    torch.cuda.synchronize()

    # Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly)
    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        # Capture 10 invocations like benchmark_moe.py
        for _ in range(10):
            fused_experts(
                a,
                w1_fp8q,
                w2_fp8q,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )
    torch.cuda.synchronize()
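    # The graphs above each capture 10 kernel invocations, so a single
    # graph.replay() measures 10 calls back to back; bench_cuda_graph below
    # divides the measured replay time by 10 to recover a per-call latency.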
    def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
        """Benchmark CUDA graph using events like benchmark_moe.py"""
        # Warmup
        for _ in range(num_warmup):
            graph.replay()
        torch.cuda.synchronize()

        # Timing
        start_event = torch.Event(enable_timing=True)
        end_event = torch.Event(enable_timing=True)

        latencies = []
        for _ in range(num_iters):
            torch.cuda.synchronize()
            start_event.record()
            graph.replay()
            end_event.record()
            end_event.synchronize()
            latencies.append(start_event.elapsed_time(end_event))

        # Divide by 10 since graph contains 10 calls
        return sum(latencies) / (num_iters * 10)

    # Benchmark parameters
    num_warmup = 5
    num_iters = 100

    # Benchmark only CUDA graphs (more reliable and faster)
    # Benchmark Triton MoE with CUDA graphs
    triton_graph_time = bench_cuda_graph(
        triton_graph, num_warmup=num_warmup, num_iters=num_iters
    )

    # Benchmark CUTLASS MoE with CUDA graphs
    cutlass_graph_time = bench_cuda_graph(
        cutlass_graph, num_warmup=num_warmup, num_iters=num_iters
    )

    # Convert ms to us and return results
    triton_time_us = triton_graph_time * 1000
    cutlass_time_us = cutlass_graph_time * 1000

    return {
        "batch_size": m,
        "triton_time_us": triton_time_us,
        "cutlass_time_us": cutlass_time_us,
    }


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    all_results = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in args.per_act_token_opts:
                    for per_out_ch in args.per_out_ch_opts:
                        print(
                            f"\n=== {model}, experts={num_experts}, topk={topk}, "
                            f"per_act={per_act_token}, per_out_ch={per_out_ch} ==="
                        )

                        config_results = []
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            result = bench_run(
                                [],  # Not used anymore
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )
                            if result:
                                config_results.append(result)

                        # Print results table for this configuration
                        if config_results:
                            print(
                                f"\n{'Batch Size':<12}"
                                f"{'Triton (us)':<15}"
                                f"{'CUTLASS (us)':<15}"
                            )
                            print("-" * 45)
                            for result in config_results:
                                print(
                                    f"{result['batch_size']:<12}"
                                    f"{result['triton_time_us']:<15.2f}"
                                    f"{result['cutlass_time_us']:<15.2f}"
                                )

                        all_results.extend(config_results)

    print(f"\nTotal benchmarks completed: {len(all_results)}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE
        across specified models/shapes/batches

        Example usage:
        python benchmark_cutlass_moe_fp8.py \
            --models "Llama-4-Maverick-17B-128E-Instruct-FP8" \
            --tp-sizes 8 \
            --batch-sizes 2 4 8 \
            --per-act-token-opts false \
            --per-out-ch-opts false

        """
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument(
        "--per-act-token-opts",
        nargs="+",
        type=lambda x: x.lower() == "true",
        default=[False, True],
        help="Per-activation token quantization options (true/false)",
    )
    parser.add_argument(
        "--per-out-ch-opts",
        nargs="+",
        type=lambda x: x.lower() == "true",
        default=[False, True],
        help="Per-output channel quantization options (true/false)",
    )

    args = parser.parse_args()
    main(args)
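# Note: --per-act-token-opts and --per-out-ch-opts parse booleans with
# `x.lower() == "true"`, so any other token silently becomes False, e.g.:
#   --per-act-token-opts true false  ->  [True, False]
#   --per-act-token-opts yes         ->  [False]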
508
benchmarks/kernels/benchmark_device_communicators.py
Normal file
@@ -0,0 +1,508 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Benchmark script for device communicators:
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
and SymmMemCommunicator (multimem, two-shot).

For NCCL symmetric memory, set the environment variables
NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1; otherwise
NCCL does not use the fast NVLS implementation for all-reduce.

Usage:
    torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]

Example:
    torchrun --nproc_per_node=2 benchmark_device_communicators.py
    --sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
"""

import json
import os
import time
from collections.abc import Callable
from contextlib import nullcontext

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
from vllm.distributed.device_communicators.pynccl import (
    PyNcclCommunicator,
    register_nccl_symmetric_ops,
)
from vllm.distributed.device_communicators.pynccl_allocator import (
    set_graph_pool_id,
)
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser

logger = init_logger(__name__)

# Default sequence lengths to benchmark
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]

# Fixed hidden size and dtype for all benchmarks
HIDDEN_SIZE = 8192
BENCHMARK_DTYPE = torch.bfloat16

# CUDA graph settings
CUDA_GRAPH_CAPTURE_CYCLES = 10
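# Each captured CUDA graph bundles CUDA_GRAPH_CAPTURE_CYCLES allreduce calls,
# so one graph.replay() measures 10 operations; benchmark_allreduce_single
# divides the replay time by this constant to recover a per-op latency
# (e.g. 2.0 ms per replay -> 0.2 ms per allreduce).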
class CommunicatorBenchmark:
    """Benchmark class for testing device communicators."""

    def __init__(
        self,
        rank: int,
        world_size: int,
        device: torch.device,
        cpu_group: ProcessGroup,
        sequence_lengths: list[int],
    ):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.cpu_group = cpu_group

        # Calculate max_size_override based on largest sequence length
        max_seq_len = max(sequence_lengths)
        max_tensor_elements = max_seq_len * HIDDEN_SIZE
        self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1

        # Initialize communicators
        self.custom_allreduce = None
        self.pynccl_comm = None
        self.symm_mem_comm = None
        self.symm_mem_comm_multimem = None
        self.symm_mem_comm_two_shot = None

        self._init_communicators()

    def _init_communicators(self):
        """Initialize all available communicators."""
        try:
            self.custom_allreduce = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
                max_size=self.max_size_override,
            )
            if not self.custom_allreduce.disabled:
                logger.info("Rank %s: CustomAllreduce initialized", self.rank)
            else:
                logger.info("Rank %s: CustomAllreduce disabled", self.rank)
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e
            )
            self.custom_allreduce = None

        try:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group, device=self.device
            )
            if not self.pynccl_comm.disabled:
                logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
                register_nccl_symmetric_ops(self.pynccl_comm)
            else:
                logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
                self.pynccl_comm = None
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e
            )
            self.pynccl_comm = None

        # Initialize variants for SymmMemCommunicator
        try:
            self.symm_mem_comm_multimem = SymmMemCommunicator(
                group=self.cpu_group,
                device=self.device,
                force_multimem=True,
                max_size_override=self.max_size_override,
            )
            if not self.symm_mem_comm_multimem.disabled:
                logger.info(
                    "Rank %s: SymmMemCommunicator (multimem) initialized", self.rank
                )
            else:
                self.symm_mem_comm_multimem = None
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s",
                self.rank,
                e,
            )
            self.symm_mem_comm_multimem = None

        try:
            self.symm_mem_comm_two_shot = SymmMemCommunicator(
                group=self.cpu_group,
                device=self.device,
                force_multimem=False,
                max_size_override=self.max_size_override,
            )
            if not self.symm_mem_comm_two_shot.disabled:
                logger.info(
                    "Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank
                )
            else:
                self.symm_mem_comm_two_shot = None
        except Exception as e:
            logger.warning(
                "Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s",
                self.rank,
                e,
            )
            self.symm_mem_comm_two_shot = None

    def benchmark_allreduce(
        self, sequence_length: int, num_warmup: int, num_trials: int
    ) -> dict[str, float]:
        """Benchmark allreduce operations for all available communicators."""

        results = {}

        # Define communicators with their benchmark functions.
        # The `c=comm` default arguments below bind each communicator at
        # lambda-definition time; a plain closure would only see the value
        # `comm` holds after the last rebinding.
        communicators = []

        if self.custom_allreduce is not None:
            comm = self.custom_allreduce
            # CustomAllreduce one-shot
            communicators.append(
                (
                    "ca_1stage",
                    lambda t, c=comm: c.custom_all_reduce(t),
                    lambda t, c=comm: c.should_custom_ar(t),
                    comm.capture(),
                    "1stage",  # env variable value
                )
            )
            # CustomAllreduce two-shot
            communicators.append(
                (
                    "ca_2stage",
                    lambda t, c=comm: c.custom_all_reduce(t),
                    lambda t, c=comm: c.should_custom_ar(t),
                    comm.capture(),
                    "2stage",  # env variable value
                )
            )

        if self.pynccl_comm is not None:
            comm = self.pynccl_comm
            communicators.append(
                (
                    "pynccl",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t: True,  # Always available if initialized
                    nullcontext(),
                    None,  # no env variable needed
                )
            )
            communicators.append(
                (
                    "pynccl-symm",
                    lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
                    lambda t: True,  # Always available if initialized
                    nullcontext(),
                    None,  # no env variable needed
                )
            )

        if self.symm_mem_comm_multimem is not None:
            comm = self.symm_mem_comm_multimem
            communicators.append(
                (
                    "symm_mem_multimem",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t, c=comm: c.should_use_symm_mem(t),
                    nullcontext(),
                    None,  # no env variable needed
                )
            )

        if self.symm_mem_comm_two_shot is not None:
            comm = self.symm_mem_comm_two_shot
            communicators.append(
                (
                    "symm_mem_two_shot",
                    lambda t, c=comm: c.all_reduce(t),
                    lambda t, c=comm: c.should_use_symm_mem(t),
                    nullcontext(),
                    None,  # no env variable needed
                )
            )

        # Benchmark each communicator
        for name, allreduce_fn, should_use_fn, context, env_var in communicators:
            # Set environment variable if needed
            if env_var is not None:
                os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
            else:
                # Clear the environment variable to avoid interference
                os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)

            latency = self.benchmark_allreduce_single(
                sequence_length,
                allreduce_fn,
                should_use_fn,
                context,
                num_warmup,
                num_trials,
            )
            if latency is not None:
                results[name] = latency

        return results

    def benchmark_allreduce_single(
        self,
        sequence_length: int,
        allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None],
        should_use_fn: Callable[[torch.Tensor], bool],
        context,
        num_warmup: int,
        num_trials: int,
    ) -> float | None:
        """Benchmark method with CUDA graph optimization."""
        try:
            # Create test tensor (2D: sequence_length x hidden_size)
            tensor = torch.randn(
                sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device
            )
            if not should_use_fn(tensor):
                return None

            torch.cuda.synchronize()
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                graph_input = tensor.clone()

                # Warmup before capture
                for _ in range(3):
                    allreduce_fn(graph_input)

                # Capture the graph using context manager
                with context:
                    graph = torch.cuda.CUDAGraph()
                    graph_pool = torch.cuda.graph_pool_handle()
                    set_graph_pool_id(graph_pool)
                    with torch.cuda.graph(graph, pool=graph_pool):
                        for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                            allreduce_fn(graph_input)

            torch.cuda.synchronize()
            for _ in range(num_warmup):
                graph.replay()
            torch.cuda.synchronize()

            torch.cuda.synchronize()
            start_time = time.perf_counter()

            for _ in range(num_trials):
                graph.replay()
            torch.cuda.synchronize()

            end_time = time.perf_counter()

            # Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES
            return (
                (end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000
            )

        except Exception as e:
            logger.error("CUDA graph benchmark failed: %s", e)
            raise RuntimeError(
                f"CUDA graph benchmark failed for communicator: {e}"
            ) from e


def _calculate_speedup_info(comm_results: dict[str, float]) -> str:
    """Calculate speedup information for a single tensor size."""
    if not comm_results:
        return "N/A"

    # Find the fastest communicator
    fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k])
    fastest_time = comm_results[fastest_comm]

    # Calculate speedup vs PyNccl if available
    if "pynccl" in comm_results:
        pynccl_time = comm_results["pynccl"]
        speedup = pynccl_time / fastest_time
        return f"{fastest_comm} ({speedup:.2f}x)"
    else:
        return f"{fastest_comm} (N/A)"


def print_results(
    results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int
):
    """Print benchmark results in a formatted table."""

    print(f"\n{'=' * 130}")
    print("Device Communicator Benchmark Results")
    print(
        f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, "
        f"Hidden Size: {HIDDEN_SIZE}"
    )
    print(f"{'=' * 130}")

    # Get all communicator names
    all_comms = set()
    for size_results in results.values():
        all_comms.update(size_results.keys())

    all_comms = sorted(all_comms)

    # Print header
    header = f"{'Tensor Shape':<20}{'Tensor Size':<15}"
    for comm in all_comms:
        header += f"{comm:<20}"
    header += f"{'Best (Speedup vs PyNccl)':<30}"
    print(header)
    print("-" * len(header))

    # Print results for each sequence length
    for seq_len in sequence_lengths:
        if seq_len in results:
            # Calculate tensor size in elements and bytes
            tensor_elements = seq_len * HIDDEN_SIZE
            tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize

            # Format tensor size (MB)
            tensor_size_mb = tensor_bytes / (1024 * 1024)
            tensor_size_str = f"{tensor_size_mb:.2f} MB"

            # Format tensor shape
            tensor_shape = f"({seq_len}, {HIDDEN_SIZE})"

            row = f"{tensor_shape:<20}{tensor_size_str:<15}"
            for comm in all_comms:
                if comm in results[seq_len]:
                    row += f"{results[seq_len][comm]:<20.3f}"
                else:
                    row += f"{'N/A':<20}"

            # Calculate speedup information
            speedup_info = _calculate_speedup_info(results[seq_len])
            row += f"{speedup_info:<30}"

            print(row)

    print(f"{'=' * 130}")
    print("All times are in milliseconds (ms) per allreduce operation")
    print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)")


def main():
    parser = FlexibleArgumentParser(description="Benchmark device communicators")

    parser.add_argument(
        "--sequence-lengths",
        type=int,
        nargs="+",
        default=DEFAULT_SEQUENCE_LENGTHS,
        help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)",
    )

    parser.add_argument(
        "--num-warmup", type=int, default=5, help="Number of warmup iterations"
    )

    parser.add_argument(
        "--num-trials", type=int, default=50, help="Number of benchmark trials"
    )

    parser.add_argument("--output-json", type=str, help="Output results to JSON file")

    args = parser.parse_args()

    # Initialize distributed
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Set device
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Get CPU process group
    cpu_group = dist.new_group(backend="gloo")

    # Disable USE_SYMM_MEM to avoid affecting the max_sizes
    # in symm_mem and custom_all_reduce for benchmark
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    # Initialize benchmark
    benchmark = CommunicatorBenchmark(
        rank, world_size, device, cpu_group, args.sequence_lengths
    )

    # Run benchmarks
    all_results = {}

    for seq_len in args.sequence_lengths:
        if rank == 0:
            logger.info(
                "Benchmarking sequence length: %s (tensor shape: %s x %s)",
                seq_len,
                seq_len,
                HIDDEN_SIZE,
            )

        results = benchmark.benchmark_allreduce(
            sequence_length=seq_len,
            num_warmup=args.num_warmup,
            num_trials=args.num_trials,
        )

        all_results[seq_len] = results

        # Synchronize between ranks
        dist.barrier()

    # Print results (only rank 0)
    if rank == 0:
        print_results(all_results, args.sequence_lengths, world_size)

        # Save to JSON if requested
        if args.output_json:
            # Add speedup information to results
            enhanced_results = {}
            for seq_len, comm_results in all_results.items():
                enhanced_results[seq_len] = {
                    "timings": comm_results,
|
||||
"speedup_info": _calculate_speedup_info(comm_results),
|
||||
}
|
||||
|
||||
output_data = {
|
||||
"world_size": world_size,
|
||||
"dtype": str(BENCHMARK_DTYPE),
|
||||
"hidden_size": HIDDEN_SIZE,
|
||||
"sequence_lengths": args.sequence_lengths,
|
||||
"num_warmup": args.num_warmup,
|
||||
"num_trials": args.num_trials,
|
||||
"cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES,
|
||||
"results": enhanced_results,
|
||||
}
|
||||
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(output_data, f, indent=2)
|
||||
|
||||
logger.info("Results saved to %s", args.output_json)
|
||||
|
||||
# Cleanup
|
||||
if cpu_group != dist.group.WORLD:
|
||||
dist.destroy_process_group(cpu_group)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
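As a sanity check on the summary logic above, a minimal sketch of how _calculate_speedup_info condenses one row of the results table (the timings are invented for illustration):

    # Hypothetical per-op timings, in ms, for one tensor size.
    comm_results = {"pynccl": 0.40, "custom_allreduce": 0.25, "symm_mem": 0.31}
    # The fastest communicator is reported with its speedup vs PyNccl:
    # 0.40 / 0.25 = 1.60x.
    assert _calculate_speedup_info(comm_results) == "custom_allreduce (1.60x)"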
1129
benchmarks/kernels/benchmark_fused_collective.py
Normal file
File diff suppressed because it is too large
427
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Normal file
@@ -0,0 +1,427 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_experts,
    fused_topk,
)
from vllm.utils.argparse_utils import FlexibleArgumentParser

DEFAULT_MODELS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "deepseek-ai/DeepSeek-V2-Lite",
    "ibm-granite/granite-3.0-1b-a400m",
    "ibm-granite/granite-3.0-3b-a800m",
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
        dtype=torch.float8_e4m3fn
    )


def bench_run(
    results: list[benchmark.Measurement],
    model: str,
    num_experts: int,
    topk: int,
    per_act_token: bool,
    per_out_ch: bool,
    mkn: tuple[int, int, int],
):
    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
            model, num_experts, topk, per_act_token, per_out_ch, mkn
        )
    )

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty(
        (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)

    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        a, score, topk, renormalize=False
    )

    ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
    ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
    c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
    c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)

    def run_triton_moe(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
        num_repeats: int,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        for _ in range(num_repeats):
            fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def run_cutlass_moe(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        ab_strides1: torch.Tensor,
        ab_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        per_act_token: bool,
        num_repeats: int,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )

        for _ in range(num_repeats):
            cutlass_moe_fp8(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                ab_strides1,
                ab_strides2,
                c_strides1,
                c_strides2,
                quant_config=quant_config,
            )

    def run_cutlass_from_graph(
        a: torch.Tensor,
        a_scale: torch.Tensor,
        w1_q: torch.Tensor,
        w2_q: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        ab_strides1: torch.Tensor,
        ab_strides2: torch.Tensor,
        c_strides1: torch.Tensor,
        c_strides2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            per_act_token_quant=per_act_token,
        )

        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return cutlass_moe_fp8(
                a,
                w1_q,
                w2_q,
                topk_weights,
                topk_ids,
                ab_strides1,
                ab_strides2,
                c_strides1,
                c_strides2,
                quant_config=quant_config,
            )

    def run_triton_from_graph(
        a: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        a_scale: torch.Tensor,
    ):
        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a_scale,
        )
        with set_current_vllm_config(
            VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
        ):
            return fused_experts(
                a,
                w1,
                w2,
                topk_weights,
                topk_ids,
                quant_config=quant_config,
            )

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(
            a,
            a_scale,
            w1_q,
            w2_q,
            w1_scale,
            w2_scale,
            ab_strides1,
            ab_strides2,
            c_strides1,
            c_strides2,
            topk_weights,
            topk_ids,
        )
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(
            a,
            w1_q,
            w2_q,
            topk_weights,
            topk_ids,
            w1_scale,
            w2_scale,
            a_scale,
        )
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "per_act_token": per_act_token,
        "ab_strides1": ab_strides1,
        "ab_strides2": ab_strides2,
        "c_strides1": c_strides1,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        w1_scale,
        w2_scale,
        a_scale,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    run_cutlass_moe(
        a,
        a_scale,
        w1_q,
        w2_q,
        w1_scale,
        w2_scale,
        ab_strides1,
        ab_strides2,
        c_strides1,
        c_strides2,
        topk_weights,
        topk_ids,
        per_act_token,
        num_warmup,
    )

    results.append(
        benchmark.Timer(
            stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time)
    )

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time)
    )

def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        # Use the parsed --batch-sizes (previously this loop
                        # iterated DEFAULT_BATCH_SIZES and ignored the flag).
                        for size_m in args.batch_sizes:
                            mkn = (size_m, size_k, size_n)
                            bench_run(
                                results,
                                model,
                                num_experts,
                                topk,
                                per_act_token,
                                per_out_ch,
                                mkn,
                            )

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark grouped GEMM (CUTLASS MoE) across specified models/shapes/batches"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
    parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
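A small usage sketch of the routing step that feeds both kernel paths above. The call mirrors the one in bench_run; the shape expectations are an assumption stated here, not asserted by the diff:

    # Assumes a CUDA device and the fused_topk import from this file.
    m, k, num_experts, topk = 4, 512, 8, 2
    a = torch.randn((m, k), device="cuda", dtype=torch.half)
    score = torch.randn((m, num_experts), device="cuda", dtype=torch.half)
    # fused_topk scores each of the m tokens against num_experts experts and
    # keeps the top-k: weights and ids should both come back as (m, topk).
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)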
94
benchmarks/kernels/benchmark_layernorm.py
Normal file
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time

import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE


@torch.inference_mode()
def main(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
    seed: int = 0,
    do_profile: bool = False,
    num_warmup_iters: int = 5,
    num_iters: int = 100,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            layer(x, residual)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--add-residual", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations. "
        "If --profile is set, this number is ignored",
    )

    args = parser.parse_args()
    print(args)

    main(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        add_residual=args.add_residual,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
    )
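For reference, a direct-call sketch equivalent to the CLI entry point above; the argument values are only an example:

    # Same effect as:
    #   python benchmark_layernorm.py --num-tokens 2048 --dtype bfloat16
    main(
        num_tokens=2048,
        hidden_size=8192,
        add_residual=False,
        dtype=torch.bfloat16,
        seed=0,
        do_profile=False,
        num_warmup_iters=5,
        num_iters=100,
    )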
1488
benchmarks/kernels/benchmark_lora.py
Normal file
File diff suppressed because it is too large
745
benchmarks/kernels/benchmark_machete.py
Normal file
@@ -0,0 +1,745 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import copy
import itertools
import math
import os
import pickle as pkl
import time
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product

import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    marlin_permute_scales,
    marlin_zero_points,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows,
    quantize_weights,
)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils.argparse_utils import FlexibleArgumentParser

DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]

NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)

if NVTX_PROFILE:
    import nvtx


def terse_type_name(dt):
    return {
        torch.bfloat16: "bf16",
        torch.float16: "fp16",
        torch.int8: "int8",
        torch.float8_e4m3fn: "fp8",
        torch.float: "float",
        torch.int: "int",
    }[dt]


@dataclass
class BenchmarkTensors:
    w_ref: torch.Tensor
    a: torch.Tensor

    w_q: torch.Tensor
    group_size: int | None
    wtype: ScalarType
    w_g_s: torch.Tensor
    w_g_zp: torch.Tensor | None
    w_ch_s: torch.Tensor | None
    w_tok_s: torch.Tensor | None


@dataclass
class TypeConfig:
    act_type: torch.dtype
    weight_type: ScalarType
    output_type: torch.dtype | None
    group_scale_type: torch.dtype | None
    group_zero_type: torch.dtype | None
    channel_scale_type: torch.dtype | None
    token_scale_type: torch.dtype | None


def rand_data(shape, dtype=torch.float16, scale=1):
    if dtype.is_floating_point:
        return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
    else:
        return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")


def quantize_and_pack(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(
        w,
        wtype,
        group_size=group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True,
    )

    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    return w_ref, w_q, w_s, w_zp


def create_bench_tensors(
    shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> list[BenchmarkTensors]:
    m, n, k = shape

    # we want to make sure that weights don't fit into L2 cache between runs so
    # we construct enough weights to exceed L2 cache, which is 50mb on a H100
    # so we target total weight size > 2*50mb
    num_weights = math.ceil(
        2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
    )

    a = rand_data((m, k), types.act_type, scale=5)

    benchmark_tensors: list[BenchmarkTensors] = []
    for _ in range(num_weights):
        w = rand_data((k, n), types.act_type, scale=5)

        if types.group_scale_type is not None:
            w = w.to(types.group_scale_type)
        if w.dtype.itemsize == 1:
            w = w.to(torch.float16)

        w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
            a.dtype,
            w,
            types.weight_type,
            types.group_scale_type,
            group_size,
            types.group_zero_type is not None,
        )

        if not a.dtype.is_floating_point:
            aiinfo = torch.iinfo(a.dtype)
            w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)

        w_ref = w_ref.to(torch.float32)

        w_ch_s = (
            None
            if types.channel_scale_type is None
            else rand_data((n,), types.channel_scale_type)
        )
        w_tok_s = (
            None
            if types.token_scale_type is None
            else rand_data((m,), types.token_scale_type)
        )

        benchmark_tensors.append(
            BenchmarkTensors(
                w_ref=w_ref,
                a=a,
                w_q=w_q_packed,
                wtype=types.weight_type,
                w_g_s=w_s,
                w_g_zp=w_zp,
                group_size=group_size,
                w_ch_s=w_ch_s,
                w_tok_s=w_tok_s,
            )
        )

    return benchmark_tensors


def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    a = bt.a
    w = bt.w_ref.to(bt.a.dtype)  # use float reference tensor
    if a.dtype not in [torch.float16, torch.bfloat16]:
        a = a.to(torch.float16)
        w = w.to(torch.float16)
    return lambda: torch.matmul(a, w)


def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    if bt.w_ch_s is not None and bt.w_tok_s is not None:
        scale_a = bt.w_tok_s.to(torch.float32)
        scale_b = bt.w_ch_s.to(torch.float32)
    else:
        scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
        scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
    w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
    return lambda: ops.cutlass_scaled_mm(
        bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
    )


def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
    device = bt.a.device

    workspace = MarlinWorkspace(
        bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
    )

    if bt.w_g_zp is None:
        w_zp = torch.empty(0, dtype=torch.int, device=device)
    else:
        w_zp = marlin_zero_points(
            bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
        )

    if bt.group_size is None:
        w_s = torch.tensor([], device="cuda", dtype=torch.half)
    else:
        w_s = marlin_permute_scales(
            bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
        )

    sort_indices = torch.empty(0, dtype=torch.int, device=device)
    g_idx = torch.empty(0, dtype=torch.int, device=device)
    w_q = ops.gptq_marlin_repack(
        bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
    )

    if bt.a.dtype.is_floating_point:
        assert bt.w_ch_s is None
        assert bt.w_tok_s is None
        assert bt.group_size is not None

        fn = lambda: ops.gptq_marlin_gemm(
            a=bt.a,
            c=None,
            b_q_weight=w_q,
            b_bias=None,
            b_scales=w_s,
            a_scales=None,
            global_scale=None,
            b_zeros=w_zp,
            g_idx=g_idx,
            perm=sort_indices,
            workspace=workspace.scratch,
            b_q_type=bt.wtype,
            size_m=bt.a.shape[0],
            size_n=bt.w_ref.shape[1],
            size_k=bt.w_ref.shape[0],
            is_k_full=True,
            is_zp_float=False,
        )
    else:
        assert bt.a.dtype == torch.int8
        assert bt.wtype == scalar_types.uint4b8
        raise NotImplementedError("QQQ is not supported anymore")

    return fn


def machete_create_bench_fn(
    bt: BenchmarkTensors, out_type: torch.dtype | None = None, schedule=None
) -> Callable:
    w_q = bt.w_q.t().contiguous().t()  # make col major
    w_q = ops.machete_prepack_B(
        w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
    )

    w_g_zp = bt.w_g_zp
    if w_g_zp is not None:
        w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))

    return lambda: ops.machete_mm(
        a=bt.a,
        b_q=w_q,
        b_type=bt.wtype,
        b_group_scales=bt.w_g_s,
        b_group_zeros=w_g_zp,
        b_group_size=bt.group_size,
        b_channel_scales=bt.w_ch_s,
        a_token_scales=bt.w_tok_s,
        out_type=out_type,
        schedule=schedule,
    )


def cutlass_w4a8_create_bench_fn(
    bt: BenchmarkTensors, out_type: torch.dtype | None = None, schedule=None
) -> Callable:
    w_q = bt.w_q.t().contiguous().t()  # make col major
    w_q = ops.cutlass_encode_and_reorder_int4b(w_q)
    # expects fp8 scales
    w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn))

    return lambda: ops.cutlass_w4a8_mm(
        a=bt.a,
        b_q=w_q,
        b_group_scales=w_s,
        b_group_size=bt.group_size,
        b_channel_scales=bt.w_ch_s,
        a_token_scales=bt.w_tok_s,
        maybe_schedule=schedule,
    )

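A worked instance of the L2-sizing rule in create_bench_tensors above; pure arithmetic, with illustrative values:

    import math
    # One 4-bit K=N=4096 weight is 4096 * 4096 * 4 bits = 8 MiB, so covering
    # 2 * 50 MiB of weight data needs ceil(100 MiB / 8 MiB) = 13 copies.
    # Cycling through 13 weights keeps iterations from reusing L2-resident data.
    k, n, size_bits = 4096, 4096, 4
    num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * size_bits))
    assert num_weights == 13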
# impl

# bench


def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
    min_run_time = 1 if not NVTX_PROFILE else 0.1
    res = TBenchmark.Timer(
        stmt="""
        for fn in fns:
            fn()
        """,
        globals={"fns": fns},
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)

    if NVTX_PROFILE:
        with (
            nvtx.annotate("mm-bench"),
            nvtx.annotate(f"{label}|{sub_label}|{description}"),
        ):
            fns[0]()

    return res


_SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None
_SWEEP_SCHEDULES_RESULTS_CSV: str | None = None


def bench(
    types: TypeConfig,
    group_size: int,
    m: int,
    k: int,
    n: int,
    label: str,
    sub_label: str,
    sweep_schedules: bool = True,
) -> list[TMeasurement]:
    benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
    sub_label += f", L={len(benchmark_tensors)}"

    name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
    if types.group_scale_type is not None:
        name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
    if types.group_zero_type is not None:
        name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
    if group_size is not None:
        name_type_string += f"-G{group_size}"
    if types.channel_scale_type is not None:
        name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
    if types.token_scale_type is not None:
        name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"

    timers = []
    # pytorch impl
    timers.append(
        bench_fns(
            label,
            sub_label,
            "torch.matmul (fp16)",
            [torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
        )
    )

    if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
                [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
            )
        )

    if types.act_type != torch.float8_e4m3fn:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"marlin ({name_type_string})",
                [marlin_create_bench_fn(bt) for bt in benchmark_tensors],
            )
        )

    # machete
    timers.append(
        bench_fns(
            label,
            sub_label,
            f"machete ({name_type_string})",
            [
                machete_create_bench_fn(bt, out_type=types.output_type)
                for bt in benchmark_tensors
            ],
        )
    )

    # cutlass w4a8
    if types.act_type == torch.float8_e4m3fn and group_size == 128:
        timers.append(
            bench_fns(
                label,
                sub_label,
                f"cutlass w4a8 ({name_type_string})",
                [
                    cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
                    for bt in benchmark_tensors
                ],
            )
        )

    if sweep_schedules:
        global _SWEEP_SCHEDULES_RESULTS

        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(
            a_type=types.act_type,
            b_type=types.weight_type,
            group_scales_type=types.group_scale_type,
            group_zeros_type=types.group_zero_type,
            token_scales_type=types.token_scale_type,
            channel_scales_type=types.channel_scale_type,
            out_type=types.output_type,
        )

        if schedules is None or len(schedules) == 0:
            raise ValueError("No schedules found to sweep")

        for schedule in reversed(schedules):
            schedule_M = int(schedule.split("_")[0].split("x")[1])

            # Prune known bad schedules
            if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
                continue

            res = bench_fns(
                label,
                sub_label,
                "machete_best",
                [
                    machete_create_bench_fn(
                        bt, out_type=types.output_type, schedule=schedule
                    )
                    for bt in benchmark_tensors
                ],
            )

            results_row = {
                "M": m,
                "K": k,
                "N": n,
                "group_size": group_size,
                "schedule": schedule,
                "median": res.median,
            }
            if _SWEEP_SCHEDULES_RESULTS is None:
                _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
            _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

            print(f"  {res.median:5.5} ", schedule)
            if not best or res.median < best.median:
                best = res
                best_schedule = schedule
        print("Best schedule:", best_schedule)
        timers.append(best)

    return timers

# runner
def print_timers(timers: list[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
    types = TypeConfig(
        act_type=args.act_type,
        weight_type=scalar_types.uint4b8
        if args.group_zero_type is None
        else scalar_types.uint4,
        output_type=args.out_type,
        group_scale_type=args.group_scale_type,
        group_zero_type=args.group_zero_type,
        channel_scale_type=args.channel_scale_type,
        token_scale_type=args.token_scale_type,
    )

    results: list[TMeasurement] = []
    for m, k, n in MKNs:
        timers = bench(
            types,
            args.group_size,
            m,
            k,
            n,
            f"{args.act_type}-gemm",
            f"MKN=({m}x{k}x{n})",
            sweep_schedules=args.sweep_schedules,
        )
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(
    data: list[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    # run() takes the full args namespace (the old call passed args.dtype,
    # which no longer exists on this parser).
    data = run(args, MKNs)

    make_output(data, MKNs, f"square_bench-{args.act_type}")


def run_range_bench(args):
    m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
    m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
    m_increment, k_increment, n_increment = (
        int(x) for x in args.dim_increment.split(",")
    )
    Ms = list(range(m_start, m_end + 1, m_increment))
    Ks = list(range(k_start, k_end + 1, k_increment))
    Ns = list(range(n_start, n_end + 1, n_increment))
    MKNs = list(product(Ms, Ks, Ns))

    data = run(args, MKNs)

    make_output(data, MKNs, f"range_bench-{args.act_type}")


def run_model_bench(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args, MKNs)
        model_bench_data.append(data)

    type_string = f"{args.act_type}"

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {type_string} {model}-TP{tp_size} ====")
        print_timers(data)

    timestr = time.strftime("%Y%m%d-%H%M%S")

    all_results = []
    for d in model_bench_data:
        all_results.extend(d)

    # pickle all data
    with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
        args_dict = vars(args)
        args_dict.pop("func")
        pkl.dump(
            {
                "args": args_dict,
                "results": all_results,
            },
            f,
        )


if __name__ == "__main__":

    def to_torch_dtype(dt):
        return {
            "bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "int8": torch.int8,
            "float8_e4m3fn": torch.float8_e4m3fn,
            "int": torch.int,
            "float": torch.float,
        }[dt]

    class ToTorchDtype(argparse.Action):
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, to_torch_dtype(values))

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.

    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 range_bench --dim-start 128,16384,16384 --dim-end 512,16384,16384 --dim-increment 64,1,1

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --act-type float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--act-type",
        action=ToTorchDtype,
        required=True,
        choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
    )
    parser.add_argument(
        "--group-scale-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--group-zero-type",
        # Use the action (like the other dtype flags); with type=to_torch_dtype
        # the converted torch dtype would fail the string choices check.
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--channel-scale-type",
        action=ToTorchDtype,
        choices=["float"],
    )
    parser.add_argument(
        "--token-scale-type",
        action=ToTorchDtype,
        choices=["float"],
    )
    parser.add_argument(
        "--out-type",
        action=ToTorchDtype,
        choices=["bfloat16", "float16"],
    )
    parser.add_argument(
        "--group-size",
        type=int,
        help="Available options are ['None', '-1', '128'], default=128",
        default=128,
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    parser.add_argument(
        "--sweep-csv-out",
        help="CSV to store sweep results",
        default="sch_sweep_results.csv",
    )
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument(
        "--dim-start",
        type=str,
        required=True,
        help="Start value for M,K,N as comma-separated list",
    )
    range_parser.add_argument(
        "--dim-end",
        type=str,
        required=True,
        help="End value (inclusive) for M,K,N as comma-separated list",
    )
    range_parser.add_argument(
        "--dim-increment",
        type=str,
        required=True,
        help="Increment value for M,K,N as comma-separated list",
    )
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()

    _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
    args.func(args)

    if _SWEEP_SCHEDULES_RESULTS is not None:
        _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
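The schedule pruning inside the sweep in bench() above, pulled out as a standalone predicate. This sketch assumes machete schedule names begin with a tile spec like "128x16_...", which is what the parsing in bench() relies on:

    def keep_schedule(schedule: str, m: int) -> bool:
        # Tile M dimension, e.g. "128x16_..." -> 16.
        schedule_m = int(schedule.split("_")[0].split("x")[1])
        # Prune tiles far larger or far smaller than the problem's M.
        return not (schedule_m >= 2 * max(m, 16) or schedule_m < m // 4)

    assert keep_schedule("128x16_abc", m=16)      # 16 < 32 and 16 >= 4
    assert not keep_schedule("128x64_abc", m=16)  # 64 >= 2 * max(16, 16)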
413
benchmarks/kernels/benchmark_marlin.py
Normal file
@@ -0,0 +1,413 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from benchmark_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||
ALLSPARK_SUPPORTED_QUANT_TYPES,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
query_marlin_supported_quant_types,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
rand_marlin_weight_fp4_like,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace,
|
||||
awq_marlin_quantize,
|
||||
marlin_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack,
|
||||
gptq_quantize_weights,
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
|
||||
|
||||
def bench_run(
|
||||
results: list[benchmark.Measurement],
|
||||
model: str,
|
||||
act_order: bool,
|
||||
is_k_full: bool,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
size_m: int,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
):
|
||||
label = "Quant Matmul"
|
||||
sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
|
||||
model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
|
||||
)
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
a = torch.randn(size_m, size_k).to(torch.half).cuda()
|
||||
b = torch.rand(size_k, size_n).to(torch.half).cuda()
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
if act_order and (group_size == -1 or group_size == size_k or has_zp):
|
||||
return
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
|
||||
marlin_24_supported = (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
repack_supported = (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in MARLIN_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
allspark_supported = (
|
||||
quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||
and group_size == -1
|
||||
and not act_order
|
||||
and is_k_full
|
||||
)
|
||||
|
||||
def gen_marlin_params():
|
||||
# Marlin quant
|
||||
marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size != 16 or act_order:
|
||||
return
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
|
||||
b.T, group_size
|
||||
)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
if group_size not in [-1, 128] or act_order:
|
||||
return
|
||||
marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
|
||||
elif group_size == 16:
|
||||
return
|
||||
elif has_zp:
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b, quant_type, group_size
|
||||
)
|
||||
else:
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
|
||||
marlin_quantize(b, quant_type, group_size, act_order)
|
||||
)
|
||||
return (
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
)
|
||||
|
||||
def gen_marlin_24_params():
|
||||
marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
|
||||
if marlin_24_supported:
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
||||
marlin_24_quantize(b, quant_type, group_size)
|
||||
)
|
||||
return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)
|
||||
|
||||
def gen_repack_params():
|
||||
q_w_gptq = None
|
||||
repack_sort_indices = None
|
||||
if repack_supported:
|
||||
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
|
||||
b, quant_type, group_size, act_order
|
||||
)
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx"
|
||||
# so that group ids are increasing
|
||||
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
if act_order:
|
||||
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||
return q_w_gptq, repack_sort_indices
|
||||
|
||||
def gen_allspark_params():
|
||||
qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
|
||||
CUBLAS_M_THRESHOLD
|
||||
) = None
|
||||
nonlocal allspark_supported
|
||||
if allspark_supported:
|
||||
properties = torch.cuda.get_device_properties(b.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
supported_arch = sm_version >= 80 and sm_version < 90
|
||||
allspark_supported = allspark_supported and supported_arch
|
||||
if supported_arch:
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
|
||||
qw = qw.to(torch.uint8)
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp
|
||||
)
|
||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||
return (
|
||||
qw_reorder,
|
||||
s_reorder,
|
||||
zp_reorder,
|
||||
sm_count,
|
||||
sm_version,
|
||||
CUBLAS_M_THRESHOLD,
|
||||
)
|
||||
|
||||
(
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
) = gen_marlin_params()
|
||||
marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
|
||||
gen_marlin_24_params()
|
||||
)
|
||||
q_w_gptq, repack_sort_indices = gen_repack_params()
|
||||
qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
|
||||
gen_allspark_params()
|
||||
)
|
||||
|
||||
# Prepare
|
||||
marlin_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
)
|
||||
marlin_24_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
|
||||
globals = {
|
||||
# Gen params
|
||||
"quant_type": quant_type,
|
||||
"group_size": group_size,
|
||||
"size_m": size_m,
|
||||
"size_n": size_n,
|
||||
"size_k": size_k,
|
||||
"a": a,
|
||||
# Marlin params
|
||||
"marlin_w_ref": marlin_w_ref,
|
||||
"marlin_q_w": marlin_q_w,
|
||||
"marlin_s": marlin_s,
|
||||
"marlin_s2": marlin_s2,
|
||||
"marlin_zp": marlin_zp,
|
||||
"marlin_g_idx": marlin_g_idx,
|
||||
"marlin_sort_indices": marlin_sort_indices,
|
||||
"marlin_workspace": marlin_workspace,
|
||||
"is_k_full": is_k_full,
|
||||
# Marlin_24 params
|
||||
"marlin_24_w_ref": marlin_24_w_ref,
|
||||
"marlin_24_q_w_comp": marlin_24_q_w_comp,
|
||||
"marlin_24_meta": marlin_24_meta,
|
||||
"marlin_24_s": marlin_24_s,
|
||||
"marlin_24_workspace": marlin_24_workspace,
|
||||
# GPTQ params
|
||||
"q_w_gptq": q_w_gptq,
|
||||
"repack_sort_indices": repack_sort_indices,
|
||||
# AllSpark W8A16 params
|
||||
"qw_reorder": qw_reorder,
|
||||
"s_reorder": s_reorder,
|
||||
"zp_reorder": zp_reorder,
|
||||
"sm_count": sm_count,
|
||||
"sm_version": sm_version,
|
||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
|
||||
# Kernels
|
||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
||||
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
|
||||
}
|
||||
|
||||
min_run_time = 1
|
||||
|
||||
# Warmup pytorch
|
||||
for _ in range(5):
|
||||
torch.matmul(a, marlin_w_ref)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="torch.matmul(a, marlin_w_ref)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="pytorch_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm_fp32",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if marlin_24_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_24_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if repack_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if allspark_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="allspark_w8a16_gemm_fp32",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
results: list[benchmark.Measurement] = []
|
||||
|
||||
for model in args.models:
|
||||
for layer in WEIGHT_SHAPES[model]:
|
||||
size_k = layer[0]
|
||||
size_n = layer[1]
|
||||
|
||||
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||
continue
|
||||
|
||||
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||
continue
|
||||
|
||||
for act_order in ACT_ORDER_OPTS:
|
||||
if (
|
||||
len(args.limit_act_order) > 0
|
||||
and act_order not in args.limit_act_order
|
||||
):
|
||||
continue
|
||||
|
||||
for is_k_full in K_FULL_OPTS:
|
||||
if (
|
||||
len(args.limit_k_full) > 0
|
||||
and is_k_full not in args.limit_k_full
|
||||
):
|
||||
continue
|
||||
|
||||
for quant_type in query_marlin_supported_quant_types():
|
||||
if (
|
||||
len(args.limit_num_bits) > 0
|
||||
and quant_type.size_bits not in args.limit_num_bits
|
||||
):
|
||||
continue
|
||||
|
||||
for group_size in (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES
|
||||
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES
|
||||
):
|
||||
if (
|
||||
len(args.limit_group_size) > 0
|
||||
and group_size not in args.limit_group_size
|
||||
):
|
||||
continue
|
||||
|
||||
# For act_order, the group_size must be less than
|
||||
# size_k
|
||||
if act_order and (group_size == size_k or group_size == -1):
|
||||
continue
|
||||
|
||||
for size_m in args.batch_sizes:
|
||||
bench_run(
|
||||
results,
|
||||
model,
|
||||
act_order,
|
||||
is_k_full,
|
||||
quant_type,
|
||||
group_size,
|
||||
size_m,
|
||||
size_k,
|
||||
size_n,
|
||||
)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
|
||||
|
||||
# For quick benchmarking use:
|
||||
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
|
||||
#
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark Marlin across specified models/shapes/batches"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys(),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||
)
|
||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
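# --- Editorial note: a minimal, self-contained sketch of the
# torch.utils.benchmark pattern the Timer calls above rely on. The matmul
# stmt is a placeholder for the marlin kernels (which need quantized
# weights); shapes and labels here are illustrative, not the file's setup.
import torch
import torch.utils.benchmark as benchmark

a = torch.randn(16, 4096, device="cuda", dtype=torch.half)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.half)
res = benchmark.Timer(
    stmt="output = torch.matmul(a, b)",
    globals={"a": a, "b": b},
    label="gemm",
    sub_label="m=16 k=4096 n=4096",
    description="torch_fp16",
).blocked_autorange(min_run_time=1)
# blocked_autorange() picks the run count automatically and returns a
# Measurement; collecting Measurements and passing them to
# benchmark.Compare(results).print() tabulates them, as main() does above.
print(res)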
@@ -1,215 +0,0 @@
import argparse
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton
from tqdm import tqdm

from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                  get_config_file_name)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def main(dtype: str):
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method, dtype=dtype)


def run_grid(bs, method, dtype: str):
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in [16, 32, 64, 128, 256]:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        for num_stages in [2, 3, 4, 5]:
                            configs.append({
                                "BLOCK_SIZE_M": block_size_m,
                                "BLOCK_SIZE_N": block_size_n,
                                "BLOCK_SIZE_K": block_size_k,
                                "GROUP_SIZE_M": group_size_m,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })

    best_config = None
    best_time_us = 1e20

    print(f'{tp_size=} {bs=}')

    for config in tqdm(configs):
        # warmup
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                    dtype=dtype,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
                dtype=dtype,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            tqdm.write(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                f'{d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # holds Dict[str, Dict[str, int]]
    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size,
                                    "float8" if dtype == "float8" else None)
    print(f"writing config to file {filename}")
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config, dtype: str) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.float16,
    )

    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

    if dtype == "float8":
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a1_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a2_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)

    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=dtype == "float8",
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='benchmark_mixtral_moe',
        description='Benchmark and tune the fused_moe kernel',
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    args = parser.parse_args()
    sys.exit(main(args.dtype))
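# --- Editorial note: the deleted run_timing above amortizes one CUDA-event
# pair over num_calls launches. A hedged, self-contained sketch of that
# idiom (the matmul is a stand-in for the fused_moe call):
import torch

x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
num_calls = 100
start.record()
for _ in range(num_calls):
    y = torch.matmul(x, x)
end.record()
end.synchronize()  # elapsed_time is valid only after the end event completes
print(start.elapsed_time(end) / num_calls, "ms per call")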
150
benchmarks/kernels/benchmark_mla_k_concat.py
Normal file
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
in MLA (Multi-head Latent Attention) prefill.

This validates that the optimization from commit 8d4142bd is beneficial across
various batch sizes, not just the originally tested batch size of 32768.
"""

import time
from collections.abc import Callable

import torch

# DeepSeek-V3 MLA dimensions
NUM_HEADS = 128
QK_NOPE_HEAD_DIM = 128
PE_DIM = 64


def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    """Original torch.cat approach with expand."""
    return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)


def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    """Optimized direct copy approach (avoids expand + cat overhead)."""
    k = torch.empty(
        (*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
        dtype=k_nope.dtype,
        device=k_nope.device,
    )
    k[..., : k_nope.shape[-1]] = k_nope
    k[..., k_nope.shape[-1] :] = k_pe
    return k


def benchmark_method(
    method: Callable,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
    num_warmup: int = 10,
    num_iters: int = 100,
) -> float:
    """Benchmark a concatenation method and return mean latency in ms."""
    # Warmup
    for _ in range(num_warmup):
        _ = method(k_nope, k_pe)
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(num_iters):
        _ = method(k_nope, k_pe)
    torch.cuda.synchronize()
    end = time.perf_counter()

    return (end - start) / num_iters * 1000  # Convert to ms


@torch.inference_mode()
def run_benchmark(dtype: torch.dtype, dtype_name: str):
    """Run benchmark for a specific dtype."""
    torch.set_default_device("cuda")

    # Batch sizes to test (powers of 2 from 32 to 65536)
    batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]

    print("=" * 80)
    print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
    print("=" * 80)
    print(
        f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
        f"k_pe=[B, 1, {PE_DIM}]"
    )
    print(f"dtype: {dtype_name}")
    print()
    print(
        f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
        f"{'Speedup':>8} | {'Reduction':>10}"
    )
    print("-" * 70)

    results = []
    for batch_size in batch_sizes:
        # Create input tensors (generate in float32 then convert for FP8 compatibility)
        k_nope = torch.randn(
            batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
        ).to(dtype)
        k_pe = torch.randn(
            batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
        ).to(dtype)

        # Benchmark both methods
        cat_time = benchmark_method(cat_method, k_nope, k_pe)
        direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)

        speedup = cat_time / direct_time
        reduction = (1 - direct_time / cat_time) * 100

        results.append((batch_size, cat_time, direct_time, speedup, reduction))

        print(
            f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
            f"{speedup:>7.2f}x | {reduction:>9.1f}%"
        )

    print("=" * 80)

    # Summary statistics
    speedups = [r[3] for r in results]
    print("\nSpeedup summary:")
    print(f"  Min: {min(speedups):.2f}x")
    print(f"  Max: {max(speedups):.2f}x")
    print(f"  Mean: {sum(speedups) / len(speedups):.2f}x")

    # Find crossover point
    crossover_batch = None
    for batch_size, _, _, speedup, _ in results:
        if speedup >= 1.0:
            crossover_batch = batch_size
            break

    print("\nConclusion:")
    if crossover_batch:
        print(f"  - Direct copy becomes beneficial at batch size >= {crossover_batch}")
        # Filter for large batches (>= 512 which is typical for prefill)
        large_batch_speedups = [r[3] for r in results if r[0] >= 512]
        if large_batch_speedups:
            avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
            print(f"  - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
        print("  - MLA prefill typically uses large batches, so optimization is effective")

    return results


@torch.inference_mode()
def main():
    # Test bfloat16
    print("\n")
    run_benchmark(torch.bfloat16, "bfloat16")

    # Test float8_e4m3fn
    print("\n")
    run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")


if __name__ == "__main__":
    main()
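# --- Editorial note: a correctness sanity check one could run alongside this
# benchmark. The two methods should be bit-identical, since direct_copy_method
# only re-implements the expand+cat as two broadcasting slice writes. The
# import path is hypothetical and assumes this file is on PYTHONPATH.
import torch
from benchmark_mla_k_concat import cat_method, direct_copy_method

k_nope = torch.randn(32, 128, 128, device="cuda", dtype=torch.bfloat16)
k_pe = torch.randn(32, 1, 64, device="cuda", dtype=torch.bfloat16)
assert torch.equal(cat_method(k_nope, k_pe), direct_copy_method(k_nope, k_pe))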
790
benchmarks/kernels/benchmark_moe.py
Normal file
@@ -0,0 +1,790 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import json
import os
import time
from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, TypedDict

import ray
import torch
from ray.experimental.tqdm_ray import tqdm

from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    _get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

FP8_DTYPE = current_platform.fp8_dtype()


def ensure_divisibility(numerator, denominator, text):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
        text, numerator, denominator
    )


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def benchmark_config(
    config: BenchmarkConfig,
    num_tokens: int,
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    block_quant_shape: list[int] = None,
    use_deep_gemm: bool = False,
) -> float:
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    if use_int8_w8a16:
        w1 = torch.randint(
            -127,
            127,
            (
                num_experts,
                shard_intermediate_size,
                hidden_size,
            ),
            dtype=torch.int8,
        )
        w2 = torch.randint(
            -127,
            127,
            (
                num_experts,
                hidden_size,
                shard_intermediate_size // 2,
            ),
            dtype=torch.int8,
        )
    else:
        w1 = torch.randn(
            num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
        )
        w2 = torch.randn(
            num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
        )
    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_int8_w8a16:
        w1_scale = torch.randn(
            (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
        )
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
            E = num_experts
            N = shard_intermediate_size // 2
            K = hidden_size
            factor_for_scale = 1e-2
            n_tiles_w1 = (2 * N + block_n - 1) // block_n
            n_tiles_w2 = (K + block_n - 1) // block_n
            k_tiles_w1 = (K + block_k - 1) // block_k
            k_tiles_w2 = (N + block_k - 1) // block_k
            w1_scale = (
                torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
                * factor_for_scale
            )
            w2_scale = (
                torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
                * factor_for_scale
            )
        else:
            w1_scale = torch.randn(num_experts, dtype=torch.float32)
            w2_scale = torch.randn(num_experts, dtype=torch.float32)

        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

        w1 = w1.to(FP8_DTYPE)
        w2 = w2.to(FP8_DTYPE)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        from vllm.model_executor.layers.fused_moe import override_config

        if use_fp8_w8a8:
            quant_dtype = torch.float8_e4m3fn
        elif use_int8_w8a16:
            quant_dtype = torch.int8
        else:
            quant_dtype = None

        quant_config = FusedMoEQuantConfig.make(
            quant_dtype=quant_dtype,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_quant_shape,
        )

        with override_config(config):
            topk_weights, topk_ids, token_expert_indices = fused_topk(
                x, input_gating, topk, renormalize=not use_deep_gemm
            )
            return fused_experts(
                x,
                w1,
                w2,
                topk_weights,
                topk_ids,
                inplace=True,
                quant_config=quant_config,
                allow_deep_gemm=use_deep_gemm,
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
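# Note on the timing arithmetic above: the CUDA graph captures 10 invocations
# of run(), so each replay measures 10 fused-MoE calls. Dividing the summed
# elapsed milliseconds by (num_iters * 10) and scaling by 1000 yields the
# mean per-call latency in microseconds; e.g. 100 replays of 2.0 ms each give
# 200 / (100 * 10) * 1000 = 200 us per call.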
def get_rocm_tuning_space(use_fp16):
    block_mn_range = [16, 32, 64, 128, 256]
    block_k_range = [16, 32, 64, 128, 256]
    if not use_fp16:
        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
    num_warps_range = [1, 2, 4, 8]
    group_m_range = [1, 4, 8, 16, 32]
    num_stage_range = [2]
    waves_per_eu_range = [0, 1, 2, 4]
    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
    kpack_range = [1, 2] if use_fp16 else []

    param_ranges = {
        "BLOCK_SIZE_M": block_mn_range,
        "BLOCK_SIZE_N": block_mn_range,
        "BLOCK_SIZE_K": block_k_range,
        "GROUP_SIZE_M": group_m_range,
        "num_warps": num_warps_range,
        "num_stages": num_stage_range,
        "waves_per_eu": waves_per_eu_range,
    }
    if use_fp16:
        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
        param_ranges["kpack"] = kpack_range

    return param_ranges


def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
    configs: list[BenchmarkConfig] = []

    if current_platform.is_rocm():
        param_ranges = get_rocm_tuning_space(use_fp16)
    else:
        # Reduced search space for faster tuning.
        # TODO(woosuk): Increase the search space and use a performance model to
        # prune the search space.
        block_m_range = [16, 32, 64, 128, 256]
        block_n_range = [32, 64, 128, 256]
        block_k_range = [64, 128, 256]
        num_warps_range = [4, 8]
        group_m_range = [1, 16, 32, 64]
        num_stage_range = [2, 3, 4, 5]

        param_ranges = {
            "BLOCK_SIZE_M": block_m_range,
            "BLOCK_SIZE_N": block_n_range,
            "BLOCK_SIZE_K": block_k_range,
            "GROUP_SIZE_M": group_m_range,
            "num_warps": num_warps_range,
            "num_stages": num_stage_range,
        }

    keys, values = zip(*param_ranges.items())
    for config_values in product(*values):
        config = dict(zip(keys, config_values))
        configs.append(config)

    # Remove configs that are not compatible with fp8 block quantization
    # BLOCK_SIZE_K must be a multiple of block_k
    # BLOCK_SIZE_N must be a multiple of block_n
    if block_quant_shape is not None and not use_fp16:
        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
        for config in configs[:]:
            if (
                config["BLOCK_SIZE_K"] % block_k != 0
                or config["BLOCK_SIZE_N"] % block_n != 0
            ):
                configs.remove(config)
    return configs


def prune_rocm_search_space(
    num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
    N1, K1 = shard_intermediate_size, hidden_size
    N2, K2 = hidden_size, shard_intermediate_size // 2
    pruned_space_1 = prune_rocm_configs(
        num_tokens * topk, N1, K1, search_space, is_fp16
    )
    pruned_space_2 = prune_rocm_configs(
        num_tokens * topk, N2, K2, search_space, is_fp16
    )
    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
    return search_space


# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
    pruned_configs = []
    elemBytes_a = 2 if is_fp16 else 1
    elemBytes_b = 2 if is_fp16 else 1

    mfma = 16 if M < 32 or N < 32 else 32

    # TODO (zhanglx): figure out the boundary between large and small gemms
    large_gemm = False
    if M >= 2048 and N >= 2048:
        large_gemm = True

    for config in configs:
        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
        num_warps = config.get("num_warps")

        if is_fp16:
            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
            if matrix_instr_nonkdim > mfma:
                continue
        if mfma == 4 and BLOCK_SIZE_K < 64:
            continue
        # some layouts could not work properly in case
        # number elements per thread is less 1
        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
            continue
        SPLIT_K = config.get("SPLIT_K", 1)
        GROUP_M = config.get("GROUP_SIZE_M")
        if is_fp16:
            if (
                matrix_instr_nonkdim > BLOCK_SIZE_M
                or matrix_instr_nonkdim > BLOCK_SIZE_N
            ):
                continue
            if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
                continue
            if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
                continue
        # Skip BLOCK_SIZE that is too large compare to M/N
        # unless BLOCK_SIZE is already small enough
        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
            continue
        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
            continue
        # skip large split_k when not necessary
        if SPLIT_K != 1 and not need_split_k(M, N, K):
            continue
        # skip split_k that leads to EVEN_K = false
        leap = SPLIT_K * BLOCK_SIZE_K
        modv = K % leap
        if modv != 0:
            continue
        # skip large GROUP_M
        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
            continue
        # out of shared memory resource
        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
        LDS = (
            BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
            + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
        )
        if LDS > 65536:
            continue
        # Skip small block sizes and num_warps for large gemm
        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
        if large_gemm:
            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
                continue
            if BLOCK_SIZE_K < 64:
                continue
            if num_warps < 4:
                continue

        pruned_configs.append(config)

    return pruned_configs


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def merge_unique_dicts(list1, list2):
    result = []
    combined_list = list1.copy()
    combined_list.extend(list2)
    for dictionary in combined_list:
        if dictionary not in result:
            result.append(dictionary)
    return result


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        block_quant_shape: list[int] = None,
        use_deep_gemm: bool = False,
    ) -> tuple[dict[str, int], float]:
        current_platform.seed_everything(self.seed)
        dtype_str = _get_config_dtype_str(
            dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
        )
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        block_n = block_quant_shape[0] if block_quant_shape else None
        block_k = block_quant_shape[1] if block_quant_shape else None
        op_config = get_moe_configs(
            num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
        )
        if op_config is None:
            config = get_default_config(
                num_tokens,
                num_experts,
                shard_intermediate_size,
                hidden_size,
                topk,
                dtype_str,
                block_quant_shape,
            )
        else:
            config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(
            config,
            num_tokens,
            num_experts,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            block_quant_shape=block_quant_shape,
            use_deep_gemm=use_deep_gemm,
        )
        return config, kernel_time

    def tune(
        self,
        num_tokens: int,
        num_experts: int,
        shard_intermediate_size: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: list[dict[str, int]],
        block_quant_shape: list[int],
        use_deep_gemm: bool,
    ) -> dict[str, int]:
        best_config = None
        best_time = float("inf")
        if current_platform.is_rocm():
            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
            search_space = prune_rocm_search_space(
                num_tokens,
                shard_intermediate_size,
                hidden_size,
                search_space,
                is_fp16,
                topk,
            )

        need_device_guard = False
        if current_platform.is_rocm():
            visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
            if visible_device != f"{self.device_id}":
                need_device_guard = True

        with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
            for config in tqdm(search_space):
                try:
                    kernel_time = benchmark_config(
                        config,
                        num_tokens,
                        num_experts,
                        shard_intermediate_size,
                        hidden_size,
                        topk,
                        dtype,
                        use_fp8_w8a8,
                        use_int8_w8a16,
                        num_iters=20,
                        block_quant_shape=block_quant_shape,
                        use_deep_gemm=use_deep_gemm,
                    )
                except triton.runtime.autotuner.OutOfResources:
                    # Some configurations may be invalid and fail to compile.
                    continue

                if kernel_time < best_time:
                    best_time = kernel_time
                    best_config = config
        now = datetime.now()
        print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
        return best_config
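# Worked example of the LDS bound in prune_rocm_configs above: with fp16
# (2 bytes/element), BLOCK_SIZE_M = BLOCK_SIZE_N = BLOCK_SIZE_K = 128 gives
# LDS = 128*128*2 + 128*128*2 = 65536 bytes, exactly the 64 KiB limit, so any
# larger tile combination is pruned before compilation is attempted.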
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    return {
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
        "num_warps": config["num_warps"],
        "num_stages": config["num_stages"],
        **(
            {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
        ),
        **(
            {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
            if "matrix_instr_nonkdim" in config
            else {}
        ),
        **({"kpack": config["kpack"]} if "kpack" in config else {}),
    }


def save_configs(
    configs: dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    block_quant_shape: list[int],
    save_dir: str,
) -> None:
    dtype_str = _get_config_dtype_str(
        dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
    )

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(
        num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
    )
    os.makedirs(save_dir, exist_ok=True)
    filename = os.path.join(save_dir, filename)
    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump({"triton_version": triton.__version__, **configs}, f, indent=4)
        f.write("\n")


def get_weight_block_size_safety(config, default_value=None):
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        return quantization_config.get("weight_block_size", default_value)
    return default_value


def main(args: argparse.Namespace):
    print(args)

    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
    if args.model_prefix:
        config = getattr(config, args.model_prefix)

    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        hidden_size = config.hidden_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    elif config.architectures[0] in (
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
        "Glm4MoeForCausalLM",
        "NemotronHForCausalLM",
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif config.architectures[0] in (
        "Qwen2MoeForCausalLM",
        "Qwen3MoeForCausalLM",
        "Qwen3NextForCausalLM",
    ):
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.moe_intermediate_size
        hidden_size = config.hidden_size
    elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
        text_config = config.get_text_config()
        E = text_config.num_experts
        topk = text_config.num_experts_per_tok
        intermediate_size = text_config.moe_intermediate_size
        hidden_size = text_config.hidden_size
    elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM",):
        E = config.num_experts
        topk = config.moe_topk[0]
        intermediate_size = config.moe_intermediate_size[0]
        hidden_size = config.hidden_size
    elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]:
        E = config.thinker_config.text_config.num_experts
        topk = config.thinker_config.text_config.num_experts_per_tok
        intermediate_size = config.thinker_config.text_config.moe_intermediate_size
        hidden_size = config.thinker_config.text_config.hidden_size
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        hidden_size = config.hidden_size
    enable_ep = bool(args.enable_expert_parallel)
    if enable_ep:
        ensure_divisibility(E, args.tp_size, "Number of experts")
        E = E // args.tp_size
        shard_intermediate_size = 2 * intermediate_size
    else:
        ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    block_quant_shape = get_weight_block_size_safety(config)

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = args.batch_size

    use_deep_gemm = bool(args.use_deep_gemm)

    if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
        # Ray will set ROCR_VISIBLE_DEVICES for device visibility
        logger.warning(
            "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility. "
            "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
        )
        val = os.environ["HIP_VISIBLE_DEVICES"]
        os.environ["ROCR_VISIBLE_DEVICES"] = val
        del os.environ["HIP_VISIBLE_DEVICES"]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
        print(f"Start tuning over {len(search_space)} configurations...")
        if use_deep_gemm:
            raise ValueError(
                "Tuning with --use-deep-gemm is not supported as it only tunes Triton "
                "kernels. Please remove the flag."
            )
        start = time.time()
        configs = _distribute(
            "tune",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    search_space,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )
        best_configs = {
            M: sort_config(config) for M, config in zip(batch_sizes, configs)
        }
        save_configs(
            best_configs,
            E,
            shard_intermediate_size,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            block_quant_shape,
            args.save_dir,
        )
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute(
            "benchmark",
            [
                (
                    batch_size,
                    E,
                    shard_intermediate_size,
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    block_quant_shape,
                    use_deep_gemm,
                )
                for batch_size in batch_sizes
            ],
        )

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
    )
    parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument(
        "--save-dir", type=str, default="./", help="Directory to save tuned results"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, nargs="+", required=False)
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--model-prefix", type=str, required=False)
    args = parser.parse_args()

    main(args)
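# --- Editorial note: example invocations, using only flags defined by the
# parser above (model ids and the save dir are illustrative); tuning needs at
# least one visible GPU because workers are Ray actors declared num_gpus=1:
#   python benchmark_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 --tp-size 2
#   python benchmark_moe.py --model deepseek-ai/DeepSeek-V3 --dtype fp8_w8a8 --tune --save-dir ./tuned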
87
benchmarks/kernels/benchmark_moe_align_block_size.py
Normal file
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import itertools

import torch

from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size,
)
from vllm.triton_utils import triton


def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    return torch.stack(
        [
            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
            for _ in range(num_tokens)
        ]
    )


# test configurations
num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512]
topk_range = [1, 2, 8]
ep_size_range = [1, 8]
configs = list(
    itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range)
)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk", "ep_size"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm"],
        line_names=["vLLM"],
        plot_name="moe-align-block-size-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, ep_size, provider):
    """Benchmark function for Triton."""
    block_size = 256
    torch.cuda.manual_seed_all(0)
    topk_ids = get_topk_ids(num_tokens, num_experts, topk)

    e_map = None
    if ep_size != 1:
        local_e = num_experts // ep_size
        e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e]
        e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
        e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: moe_align_block_size(
                topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True
            ),
            quantiles=quantiles,
        )

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_experts",
        type=int,
        default=64,
        choices=[8, 16, 32, 64, 128, 256],
    )
    parser.add_argument(
        "--topk",
        type=int,
        default=8,
        choices=[2, 4, 8],
        help="Top-k value for correctness check.",
    )
    args = parser.parse_args()

    benchmark.run(print_data=True, show_plots=True)
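# --- Editorial note: illustration of the expert-parallel map built in
# benchmark() above. With num_experts=8 and ep_size=2, each rank owns
# local_e=4 experts; e_map sends the chosen global ids to local ids 0..3 and
# marks the rest -1, e.g. e_ids=[6, 1, 4, 2] yields
#   e_map = [-1, 1, 3, -1, 2, -1, 0, -1]
# so moe_align_block_size can skip tokens routed to non-local experts.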
428
benchmarks/kernels/benchmark_moe_permute_unpermute.py
Normal file
@@ -0,0 +1,428 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
from typing import Any, TypedDict

import ray
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    _moe_permute,
    _moe_unpermute_and_reduce,
    moe_permute,
    moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser

FP8_DTYPE = current_platform.fp8_dtype()


class BenchmarkConfig(TypedDict):
    BLOCK_SIZE_M: int
    BLOCK_SIZE_N: int
    BLOCK_SIZE_K: int
    GROUP_SIZE_M: int
    num_warps: int
    num_stages: int


def benchmark_permute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    # output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare(i: int):
        input_gating.copy_(gating_output[i])

    def run():
        if use_customized_permute:
            (
                permuted_hidden_states,
                a1q_scale,
                first_token_off,
                inv_perm_idx,
                m_indices,
            ) = moe_permute(
                qhidden_states,
                a1q_scale=None,
                topk_ids=topk_ids,
                n_expert=num_experts,
                expert_map=None,
                align_block_size=align_block_size,
            )
        else:
            (
                permuted_hidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = _moe_permute(
                qhidden_states, None, topk_ids, num_experts, None, align_block_size
            )

    # JIT compilation & warmup
    run()
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run()
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        prepare(i)
        torch.cuda.synchronize()

        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg
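# Note: benchmark_permute above times only the scatter that groups each
# token's top-k copies by destination expert (so the per-expert GEMMs can run
# on contiguous rows); benchmark_unpermute below times the inverse gather
# that restores token order and reduces the top-k partial outputs with
# topk_weights.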
def benchmark_unpermute(
    num_tokens: int,
    num_experts: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
    use_customized_permute: bool = False,
) -> float:
    # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
    output_hidden_states = torch.empty_like(hidden_states)
    if use_fp8_w8a8:
        align_block_size = 128  # deepgemm needs 128 m aligned block
        qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
    else:
        align_block_size = None
        qhidden_states = hidden_states

    input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)

    topk_weights, topk_ids, token_expert_indices = fused_topk(
        qhidden_states, input_gating, topk, False
    )

    def prepare():
        if use_customized_permute:
            (
                permuted_hidden_states,
                a1q_scale,
                first_token_off,
                inv_perm_idx,
                m_indices,
            ) = moe_permute(
                qhidden_states,
                a1q_scale=None,
                topk_ids=topk_ids,
                n_expert=num_experts,
                expert_map=None,
                align_block_size=align_block_size,
            )
            # convert to fp16/bf16 as gemm output
            return (
                permuted_hidden_states.to(dtype),
                first_token_off,
                inv_perm_idx,
                m_indices,
            )
        else:
            (
                permuted_qhidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = _moe_permute(
                qhidden_states, None, topk_ids, num_experts, None, align_block_size
            )
            # convert to fp16/bf16 as gemm output
            return (
                permuted_qhidden_states.to(dtype),
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            )

    def run(input: tuple):
        if use_customized_permute:
            (
                permuted_hidden_states,
                first_token_off,
                inv_perm_idx,
                m_indices,
            ) = input
            output = torch.empty_like(hidden_states)
            moe_unpermute(
                output,
                permuted_hidden_states,
                topk_weights,
                inv_perm_idx,
                first_token_off,
            )
        else:
            (
                permuted_hidden_states,
                a1q_scale,
                sorted_token_ids,
                expert_ids,
                inv_perm,
            ) = input
            _moe_unpermute_and_reduce(
                output_hidden_states,
                permuted_hidden_states,
                inv_perm,
                topk_weights,
                True,
            )

    # JIT compilation & warmup
    input = prepare()
    run(input)
    torch.cuda.synchronize()

    # Capture 10 invocations with CUDA graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for _ in range(10):
            run(input)
    torch.cuda.synchronize()

    # Warmup
    for _ in range(5):
        graph.replay()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for i in range(num_iters):
        torch.cuda.synchronize()
        start_event.record()
        graph.replay()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
    graph.reset()
    return avg


@ray.remote(num_gpus=1)
class BenchmarkWorker:
    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
        # Get the device ID to allocate tensors and kernels
        # on the respective GPU. This is required for Ray to work
        # correctly with multi-GPU tuning on the ROCm platform.
        self.device_id = int(ray.get_gpu_ids()[0])

    def benchmark(
        self,
        num_tokens: int,
        num_experts: int,
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        use_customized_permute: bool = False,
    ) -> tuple[dict[str, int], float]:
        current_platform.seed_everything(self.seed)

        permute_time = benchmark_permute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        unpermute_time = benchmark_unpermute(
            num_tokens,
            num_experts,
            hidden_size,
            topk,
            dtype,
            use_fp8_w8a8,
            use_int8_w8a16,
            num_iters=100,
            use_customized_permute=use_customized_permute,
        )
        return permute_time, unpermute_time


def get_weight_block_size_safety(config, default_value=None):
    quantization_config = getattr(config, "quantization_config", {})
    if isinstance(quantization_config, dict):
        return quantization_config.get("weight_block_size", default_value)
    return default_value


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
    elif (
        config.architectures[0] == "DeepseekV3ForCausalLM"
        or config.architectures[0] == "DeepseekV2ForCausalLM"
        or config.architectures[0] == "Glm4MoeForCausalLM"
    ):
        E = config.n_routed_experts
        topk = config.num_experts_per_tok
    elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
        E = config.num_experts
        topk = config.num_experts_per_tok
    else:
        # Support for llama4
        config = config.get_text_config()
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok

    hidden_size = config.hidden_size
    dtype = torch.float16 if current_platform.is_rocm() else config.dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"
    use_customized_permute = args.use_customized_permute

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: list[Any]) -> list[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    outputs = _distribute(
        "benchmark",
        [
            (
                batch_size,
                E,
                hidden_size,
                topk,
                dtype,
                use_fp8_w8a8,
                use_int8_w8a16,
                use_customized_permute,
            )
            for batch_size in batch_sizes
        ],
    )

    for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
        print(f"Batch size: {batch_size}")
        print(f"Permute time: {permute:.2f} us")
        print(f"Unpermute time: {unpermute:.2f} us")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    parser.add_argument(
        "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
    )
    parser.add_argument("--use-customized-permute", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    args = parser.parse_args()

    main(args)
322
benchmarks/kernels/benchmark_mrope.py
Normal file
@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
# It generates test data, runs benchmarks, and saves results to a CSV file.
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
# == Usage Examples ==
#
# Single model benchmark:
# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
#     --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models benchmark:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
#     --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different TP sizes:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \
#     --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different token counts:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
#     --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
import csv
import os
import time
from datetime import datetime
from typing import Any

import numpy as np
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.utils.argparse_utils import FlexibleArgumentParser

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Generate test data for given configuration."""
    # Create 2D positions (3, num_tokens) for multimodal case
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )

    # Create query and key tensors
    query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)

    return positions, query, key
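# Note: the (3, num_tokens) positions mirror mrope's multimodal indexing,
# where the three rows carry separate position streams (temporal/height/width
# in the Qwen2-VL design, assumed here); the benchmark only needs the shape
# to match what the kernel expects.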
def calculate_stats(times: list[float]) -> dict[str, float]:
|
||||
"""Calculate statistics from a list of times."""
|
||||
times_array = np.array(times)
|
||||
return {
|
||||
"mean": np.mean(times_array),
|
||||
"median": np.median(times_array),
|
||||
"p99": np.percentile(times_array, 99),
|
||||
"min": np.min(times_array),
|
||||
"max": np.max(times_array),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_mrope(
|
||||
model_name: str,
|
||||
num_tokens: int,
|
||||
head_dim: int,
|
||||
tp_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 8192,
|
||||
is_neox_style: bool = True,
|
||||
rope_parameters: dict[str, Any] | None = None,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
seed: int = 0,
|
||||
warmup_iter: int = 10,
|
||||
benchmark_iter: int = 100,
|
||||
csv_writer=None,
|
||||
):
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
# the parameters to compute the q k v size based on tp_size
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
max_position=max_position,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_parameters=rope_parameters,
|
||||
dtype=dtype,
|
||||
).to(device=device)
|
||||
|
||||
print(80 * "=")
|
||||
print(
|
||||
f"Evaluating model: {model_name} "
|
||||
f"with tp_size: {tp_size} "
|
||||
f"and num_tokens: {num_tokens}, "
|
||||
f"dtype: {dtype}"
|
||||
)
|
||||
|
||||
# create q k v input tensors
|
||||
# create rotary pos emb input tensors
|
||||
positions, query, key = generate_test_data(
|
||||
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
|
||||
)
|
||||
|
||||
# Warm up
|
||||
for _ in range(warmup_iter):
|
||||
mrope_helper_class.forward_native(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
mrope_helper_class.forward_cuda(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
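
    # Note on methodology: the loops below time with host-side wall clock
    # (time.time) and a torch.cuda.synchronize() before and after each
    # iteration, so every sample includes kernel launch overhead as well as
    # device time.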

    # Time reference implementation
    torch_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()

        mrope_helper_class.forward_native(
            positions,
            query_clone,
            key_clone,
        )

        torch.cuda.synchronize()
        torch_times.append(time.time() - start_time)

    # Time triton kernel implementation
    triton_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()
        mrope_helper_class.forward_cuda(
            positions,
            query_clone,
            key_clone,
        )
        torch.cuda.synchronize()
        triton_times.append(time.time() - start_time)

    # Calculate statistics
    torch_stats = calculate_stats(torch_times)
    triton_stats = calculate_stats(triton_times)
    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")

    print(
        f"Torch implementation: "
        f"mean={torch_stats['mean']:.8f}s, "
        f"median={torch_stats['median']:.8f}s, "
        f"p99={torch_stats['p99']:.8f}s"
    )

    print(
        f"Triton implementation: "
        f"mean={triton_stats['mean']:.8f}s, "
        f"median={triton_stats['median']:.8f}s, "
        f"p99={triton_stats['p99']:.8f}s"
    )

    print(
        f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
    )

    # Write to CSV
    if csv_writer:
        row = [
            model_name,
            tp_size,
            num_tokens,
            num_heads,
            num_kv_heads,
            head_dim,
            max_position,
            is_neox_style,
            str(rope_parameters),
            str(dtype).split(".")[-1],
            torch_stats["mean"],
            torch_stats["median"],
            torch_stats["p99"],
            torch_stats["min"],
            torch_stats["max"],
            triton_stats["mean"],
            triton_stats["median"],
            triton_stats["p99"],
            triton_stats["min"],
            triton_stats["max"],
            torch_stats["mean"] / triton_stats["mean"],  # speedup
        ]
        csv_writer.writerow(row)

    return torch_stats, triton_stats


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--tp-size", type=int, default=1)
    parser.add_argument("--warmup-iter", type=int, default=10)
    parser.add_argument("--benchmark-iter", type=int, default=100)
    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
    args = parser.parse_args()
    print(args)

    # Create CSV file for results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"

    with open(csv_filename, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        # Write header
        header = [
            "model_name",
            "tp_size",
            "num_tokens",
            "num_heads",
            "num_kv_heads",
            "head_dim",
            "max_position",
            "is_neox_style",
            "rope_parameters",
            "dtype",
            "torch_mean",
            "torch_median",
            "torch_p99",
            "torch_min",
            "torch_max",
            "triton_mean",
            "triton_median",
            "triton_p99",
            "triton_min",
            "triton_max",
            "speedup",
        ]
        csv_writer.writerow(header)

        model_tp_dict = {}
        if args.model_name == "":
            model_tp_dict = {
                "Qwen/Qwen2-VL-2B-Instruct": [1],
                "Qwen/Qwen2-VL-7B-Instruct": [1],
                "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
                "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
            }
        else:
            model_tp_dict[args.model_name] = [args.tp_size]

        if args.num_tokens is None:
            num_tokens_list = [2**i for i in range(0, 18)]
        else:
            num_tokens_list = args.num_tokens

        for model_name, tp_list in model_tp_dict.items():
            config = get_config(model_name, trust_remote_code=args.trust_remote_code)
            for tp_size in tp_list:
                # get the model config
                total_num_kv_heads = config.num_key_value_heads
                total_num_heads = config.num_attention_heads
                num_heads = total_num_heads // tp_size
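                # max(1, ...) below guards the case where tp_size exceeds the
                # number of KV heads; with GQA each rank then replicates one.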
                num_kv_heads = max(1, total_num_kv_heads // tp_size)
                head_dim = config.hidden_size // total_num_heads
                q_size = num_heads * head_dim
                kv_size = num_kv_heads * head_dim
                is_neox_style = True
                rope_parameters = config.rope_parameters
                max_position = config.max_position_embeddings

                for num_tokens in num_tokens_list:
                    benchmark_mrope(
                        model_name=model_name,
                        num_tokens=num_tokens,
                        head_dim=head_dim,
                        tp_size=tp_size,
                        num_heads=num_heads,
                        num_kv_heads=num_kv_heads,
                        max_position=max_position,
                        is_neox_style=is_neox_style,
                        rope_parameters=rope_parameters,
                        dtype=getattr(torch, args.dtype),
                        seed=args.seed,
                        warmup_iter=args.warmup_iter,
                        benchmark_iter=args.benchmark_iter,
                        csv_writer=csv_writer,
                    )

    print(f"Benchmark results saved to {csv_filename}")
@@ -1,15 +1,25 @@
import argparse
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
import time
from typing import Optional

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    create_kv_caches_with_random,
)

NUM_BLOCKS = 1024
logger = init_logger(__name__)

NUM_BLOCKS = 128 * 1024
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256


@torch.inference_mode()
@@ -26,27 +36,20 @@ def main(
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
    kv_cache_dtype: str | None = None,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    current_platform.seed_everything(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    query = torch.empty(
        num_seqs, num_query_heads, head_size, dtype=dtype, device=device
    )
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)

    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
@@ -54,30 +57,38 @@ def main(

    # Create the block tables.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = []
    block_tables_lst: list[list[int]] = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
            random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
        block_tables_lst.append(block_table)

    block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)

    # Create the KV cache.
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS,
        block_size,
        1,
        num_kv_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        if current_platform.is_rocm():
            global PARTITION_SIZE
            if not args.custom_paged_attn and not current_platform.is_navi():
                PARTITION_SIZE = 1024
            else:
                PARTITION_SIZE = PARTITION_SIZE_ROCM
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
@@ -97,7 +108,7 @@ def main(
        start_time = time.perf_counter()

        # Using default kv_scale
        kv_scale = 1.0
        k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

        for _ in range(num_iters):
            if version == "v1":
@@ -114,34 +125,58 @@ def main(
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    kv_scale,
                    k_scale,
                    v_scale,
                )
            elif version == "v2":
                ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    kv_scale,
                )
                if not args.custom_paged_attn:
                    ops.paged_attention_v2(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
                else:
                    ops.paged_attention_rocm(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        None,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        k_scale,
                        v_scale,
                    )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
@@ -157,39 +192,43 @@ def main(
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
if __name__ == "__main__":
    logger.warning(
        "This script benchmarks the paged attention kernel. "
        "By default this is no longer used in vLLM inference."
    )

    parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
    parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--seq_len", type=int, default=4096)
    parser.add_argument("--seq-len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type. '
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
        'common inference criteria.')
        help="Data type for kv cache storage. If 'auto', will use model "
        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
    )
    parser.add_argument(
        "--custom-paged-attn", action="store_true", help="Use custom paged attention"
    )
    args = parser.parse_args()
    print(args)

159
benchmarks/kernels/benchmark_per_token_group_quant.py
Normal file
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import math
from collections.abc import Callable
from contextlib import contextmanager
from unittest.mock import patch

import torch

from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
from vllm.platforms import current_platform


@contextmanager
def _triton_mode():
    """Temporarily force the Triton fallback path"""
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        yield


def _time_cuda(
    fn: Callable[[], tuple[torch.Tensor, torch.Tensor]],
    warmup_iters: int,
    bench_iters: int,
) -> float:
    # warmup
    for _ in range(warmup_iters):
        fn()
    torch.cuda.synchronize()
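
    # CUDA events record timestamps on the device, so the measurement below
    # captures GPU elapsed time for the whole loop and amortizes per-launch
    # overhead across bench_iters iterations.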
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(bench_iters):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / bench_iters  # ms/iter


def _run_single(
    shape: tuple[int, int],
    group_size: int,
    dtype: str,
    *,
    column_major: bool = False,
    scale_ue8m0: bool = False,
    warmup_iters: int,
    bench_iters: int,
) -> None:
    num_tokens, hidden_dim = shape

    device = torch.device("cuda")
    torch.manual_seed(42)
    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8

    if dtype == "fp8":

        def cuda_impl():
            return fp8_utils.per_token_group_quant_fp8(
                x,
                group_size,
                column_major_scales=column_major,
                use_ue8m0=scale_ue8m0,
            )

        def triton_impl():
            with _triton_mode():
                return fp8_utils.per_token_group_quant_fp8(
                    x,
                    group_size,
                    column_major_scales=column_major,
                    use_ue8m0=scale_ue8m0,
                )
    elif dtype == "int8":

        def cuda_impl():
            return int8_utils.per_token_group_quant_int8(x, group_size)

        def triton_impl():
            with _triton_mode():
                return int8_utils.per_token_group_quant_int8(x, group_size)
    else:
        raise ValueError("dtype must be 'fp8' or 'int8'")

    cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
    triton_ms = _time_cuda(triton_impl, warmup_iters, bench_iters)

    speedup = triton_ms / cuda_ms if cuda_ms else math.inf

    cfg_desc = (
        f"shape={shape} gs={group_size:<3} col_major={column_major:<5} "
        f"ue8m0={scale_ue8m0:<5} dtype={dtype}"
    )
    print(
        f"{cfg_desc:55} | CUDA {cuda_ms:7.3f} ms | Triton {triton_ms:7.3f} ms | "
        f"speed-up ×{speedup:5.2f}"
    )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--warmup-iters", type=int, default=10)
    parser.add_argument("--bench-iters", type=int, default=100)
    parser.add_argument("--dtype", choices=["fp8", "int8", "both"], default="both")
    return parser.parse_args()


if __name__ == "__main__":
    if not current_platform.is_cuda():
        raise RuntimeError("CUDA device is required to run this benchmark.")

    args = parse_args()
    warmup_iters, bench_iters = args.warmup_iters, args.bench_iters

    shapes = [(32, 128), (64, 256), (16, 512)]
    group_sizes = [64, 128]

    dtypes = ["fp8", "int8"] if args.dtype == "both" else [args.dtype]

    header = (
        "Configuration".ljust(55)
        + " | "
        + "CUDA (ms)".center(12)
        + " | "
        + "Triton (ms)".center(13)
        + " | "
        + "Speed-up"
    )
    print(header)
    print("-" * len(header))

    for dtype in dtypes:
        for shape in shapes:
            for gs in group_sizes:
                if dtype == "fp8":
                    for col_major in (False, True):
                        for ue8m0 in (False, True):
                            _run_single(
                                shape,
                                gs,
                                dtype,
                                column_major=col_major,
                                scale_ue8m0=ue8m0,
                                warmup_iters=warmup_iters,
                                bench_iters=bench_iters,
                            )
                else:  # INT8 has no col-major / ue8m0 switches
                    _run_single(
                        shape,
                        gs,
                        dtype,
                        warmup_iters=warmup_iters,
                        bench_iters=bench_iters,
                    )
109
benchmarks/kernels/benchmark_quant.py
Normal file
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time

import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE


@torch.inference_mode()
def main(
    num_tokens: int,
    hidden_size: int,
    static_scale: bool,
    quant_dtype: torch.dtype,
    dtype: torch.dtype,
    seed: int = 0,
    do_profile: bool = False,
    num_warmup_iters: int = 5,
    num_iters: int = 100,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
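    # With --static-scale the kernels receive a precomputed scale; passing
    # None instead exercises the dynamic path, where the kernel derives the
    # scale from the input on every call.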
    scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if quant_dtype == torch.int8:
                ops.scaled_int8_quant(x, scale)
            else:
                ops.scaled_fp8_quant(x, scale)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported dtype: {dt}")

    parser = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel."
    )
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--static-scale", action="store_true")
    parser.add_argument(
        "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
    )
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument(
        "--num-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations. "
        "If --profile is set, this number is ignored",
    )

    args = parser.parse_args()
    print(args)

    main(
        num_tokens=args.num_tokens,
        hidden_size=args.hidden_size,
        static_scale=args.static_scale,
        quant_dtype=to_torch_dtype(args.quant_dtype),
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        num_warmup_iters=args.num_warmup_iters,
        num_iters=args.num_iters,
    )
172
benchmarks/kernels/benchmark_reshape_and_cache.py
Normal file
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    create_kv_caches_with_random,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    num_iters: int,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Return latency (seconds) for given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    current_platform.seed_everything(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    # to free unused memory
    del key_caches, value_caches

    # compute per-kernel scaling factors for fp8 conversion (if used).
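    # amax / 64 is a simple heuristic: the largest observed value maps to 64
    # after scaling, leaving ample headroom inside the fp8 e4m3 range (~448).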
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    function_under_test = lambda: ops.reshape_and_cache(
        key,  # noqa: F821
        value,  # noqa: F821
        key_cache,  # noqa: F821
        value_cache,  # noqa: F821
        slot_mapping,  # noqa: F821
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if benchmark_mode == "cudagraph":
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            function_under_test()
        torch.cuda.synchronize()
        function_under_test = lambda: g.replay()
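        # Replaying the captured graph re-issues the recorded kernel launch
        # without per-call CPU dispatch cost, so "cudagraph" mode isolates
        # device-side kernel time.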

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
    for exp in range(1, 17):
        n_tok = 2**exp
        lat = run_benchmark(
            num_tokens=n_tok,
            num_heads=args.num_heads,
            head_size=args.head_size,
            block_size=args.block_size,
            num_blocks=args.num_blocks,
            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
            kv_cache_dtype=args.kv_cache_dtype,
            num_iters=args.iters,
            benchmark_mode=args.mode,
            device="cuda",
        )
        rows.append([n_tok, lat * 1e6])  # convert to microseconds

    print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
    print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 128)

    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )

    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )

    parser.add_argument("--iters", type=int, default=200)

    parser.add_argument(
        "--mode",
        type=str,
        choices=["cudagraph", "no_graph"],
        default="cudagraph",
    )

    args = parser.parse_args()

    main(args)
210
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
Normal file
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    create_kv_caches_with_random_flash,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    num_iters: int,
    implementation: str,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Return latency (seconds) for given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    if implementation not in ("cuda", "triton"):
        raise ValueError(
            f"Unsupported implementation: {implementation}. "
            "Only 'cuda' and 'triton' are supported."
        )
    if implementation == "triton" and kv_cache_layout == "HND":
        return float("nan")  # Triton does not support HND layout yet.

    current_platform.seed_everything(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random_flash(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    # to free unused memory
    del key_caches, value_caches

    # compute per-kernel scaling factors for fp8 conversion (if used).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    if implementation == "cuda":
        function_under_test = lambda: ops.reshape_and_cache_flash(
            key,  # noqa: F821
            value,  # noqa: F821
            key_cache,  # noqa: F821
            value_cache,  # noqa: F821
            slot_mapping,  # noqa: F821
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    else:
        function_under_test = lambda: triton_reshape_and_cache_flash(
            key,  # noqa: F821
            value,  # noqa: F821
            key_cache,  # noqa: F821
            value_cache,  # noqa: F821
            slot_mapping,  # noqa: F821
            kv_cache_dtype,
            k_scale,
            v_scale,
        )
    if benchmark_mode == "cudagraph":
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            function_under_test()
        torch.cuda.synchronize()
        function_under_test = lambda: g.replay()

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
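    # Sweep both KV-cache layouts: with NHD a cache block is laid out as
    # [block_size, num_heads, head_size]; HND transposes this to
    # [num_heads, block_size, head_size].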
    for layout in ["NHD", "HND"]:
        for exp in range(1, 17):
            n_tok = 2**exp
            lat = run_benchmark(
                num_tokens=n_tok,
                num_heads=args.num_heads,
                head_size=args.head_size,
                block_size=args.block_size,
                num_blocks=args.num_blocks,
                dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
                kv_cache_dtype=args.kv_cache_dtype,
                kv_cache_layout=layout,
                num_iters=args.iters,
                implementation=args.implementation,
                benchmark_mode=args.mode,
                device="cuda",
            )
            rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])

    print(
        f"Benchmark results for implementation {args.implementation}"
        f" (measuring with {args.mode}):"
    )
    print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 512)

    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )

    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )

    parser.add_argument("--iters", type=int, default=100)

    parser.add_argument(
        "--implementation",
        type=str,
        choices=["cuda", "triton"],
        default="cuda",
    )

    parser.add_argument(
        "--mode",
        type=str,
        choices=["cudagraph", "no_graph"],
        default="cudagraph",
    )

    args = parser.parse_args()

    main(args)
255
benchmarks/kernels/benchmark_rmsnorm.py
Normal file
@@ -0,0 +1,255 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools

import torch
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn

from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton


class HuggingFaceRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual


def rmsnorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor | None = None,
    eps: float = 1e-6,
):
    naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
    naive_norm.weight = nn.Parameter(weight)
    naive_norm = naive_norm.to(x.device)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    output = naive_norm(x, residual)

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output


def rmsnorm_flashinfer(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor | None = None,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    if residual is not None:
        fused_add_rmsnorm(x, residual, weight, eps)
        output = (x, residual)
    else:
        output = rmsnorm(x, weight, eps)

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output


def rmsnorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor | None = None,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    if residual is not None:
        vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
        output = (x, residual)
    else:
        out = torch.empty_like(x)
        vllm_ops.rms_norm(out, x, weight, eps)
        output = out

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output
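
# Note: the fused residual paths above (flashinfer fused_add_rmsnorm and vLLM
# fused_add_rms_norm) update x and residual in place, which is why the
# correctness check and benchmarks below pass clones of their inputs.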


def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x) if use_residual else None

    output_naive = rmsnorm_naive(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_flashinfer = rmsnorm_flashinfer(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_vllm = rmsnorm_vllm(
        x.clone(), weight, residual.clone() if residual is not None else None
    )

    if use_residual:
        output_naive = output_naive[0]
        output_flashinfer = output_flashinfer[0]
        output_vllm = output_vllm[0]

    print(f"Naive output={output_naive}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"vLLM output={output_vllm}")

    if torch.allclose(
        output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
    ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))


def get_benchmark(use_residual):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["head_num", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["huggingface", "flashinfer", "vllm"],
            line_names=["HuggingFace", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
            args={},
        )
    )
    def benchmark(head_num, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        hidden_size = head_num * 128  # assuming head_dim = 128

        x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
        weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
        residual = torch.randn_like(x) if use_residual else None

        quantiles = [0.5, 0.2, 0.8]

        if provider == "huggingface":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_naive(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        elif provider == "flashinfer":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_flashinfer(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_vllm(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length",
    )
    parser.add_argument(
        "--hidden-size",
        type=int,
        default=4096,
        help="Hidden size (2nd dimension) of the sequence",
    )
    parser.add_argument(
        "--use-residual", action="store_true", help="Whether to use residual connection"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/rmsnorm/",
        help="Path to save rmsnorm benchmark results",
    )

    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_size=args.hidden_size,
        use_residual=args.use_residual,
    )

    # Get the benchmark function with proper use_residual setting
    benchmark = get_benchmark(args.use_residual)
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
@@ -1,121 +1,106 @@
import argparse
from itertools import accumulate
from typing import Optional
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools

import nvtx
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

batch_size_range = [2**i for i in range(0, 8, 2)]
seq_len_range = [2**i for i in range(6, 10, 1)]
num_heads_range = [32, 48]
configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "type": "linear",
                                "factor": tuple(scaling_factors)
                            })
    # non-batched RoPE takes only one scaling factor, we create multiple
    # instances to simulate the same behavior
    non_batched_ropes = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                     {
                         "type": "linear",
                         "factor": (scaling_factor, )
                     }))
def get_benchmark(head_size, rotary_dim, is_neox_style, device):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len", "num_heads"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["torch", "flashinfer", "vllm"],
            line_names=["PyTorch", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
            args={},
        )
    )
    def benchmark(batch_size, seq_len, num_heads, provider):
        dtype = torch.bfloat16
        max_position = 8192
        rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
        rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
        rope = rope.to(dtype=dtype, device=device)
        cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)
        positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
        query = torch.randn(
            (batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
        )
        key = torch.randn_like(query)

    # create query offsets for batched RoPE, we concat multiple kv cache
    # together and each query needs to find the right kv cache of its type
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()
        quantiles = [0.5, 0.2, 0.8]

    # batched queries of the same type together for non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
    packed_qkr = zip(queries, keys, non_batched_ropes)
    # synchronize before start timing
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        for q, k, r in packed_qkr:
            r.forward(positions, q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()
        if provider == "torch":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rope.forward_native(positions, query.clone(), key.clone()),
                quantiles=quantiles,
            )
        elif provider == "flashinfer":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: torch.ops.vllm.flashinfer_rotary_embedding(
                    positions,
                    query.clone(),
                    key.clone(),
                    head_size,
                    cos_sin_cache,
                    is_neox_style,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the rotary embedding kernels.")
if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--is-neox-style", type=bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument(
        "--dtype", type=str, choices=["bfloat16", "float"], default="float"
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
    )
    parser.add_argument("--save-path", type=str, default="./configs/rope/")
    args = parser.parse_args()

    # Get the benchmark function
    benchmark = get_benchmark(
        args.head_size, args.rotary_dim, args.is_neox_style, args.device
    )
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)

94
benchmarks/kernels/benchmark_shapes.py
Normal file
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
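
# Each entry below is a [K, N] GEMM weight shape; for the dense models the
# four shapes per TP setting appear to correspond to the fused QKV,
# attention-output, fused gate/up, and down projections (a reading of the
# shapes, not labels from the source).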

WEIGHT_SHAPES = {
    "ideal": [[4 * 256 * 32, 256 * 32]],
    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],
        [2048, 4096],
        [4096, 14336],
        [7168, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP4": [
        [4096, 1536],
        [1024, 4096],
        [4096, 7168],
        [3584, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP1": [
        [4096, 12288],
        [4096, 4096],
        [4096, 22016],
        [11008, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP2": [
        [4096, 6144],
        [2048, 4096],
        [4096, 11008],
        [5504, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP4": [
        [4096, 3072],
        [1024, 4096],
        [4096, 5504],
        [2752, 4096],
    ],
    "meta-llama/Llama-2-13b-hf/TP1": [
        [5120, 15360],
        [5120, 5120],
        [5120, 27648],
        [13824, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP2": [
        [5120, 7680],
        [2560, 5120],
        [5120, 13824],
        [6912, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP4": [
        [5120, 3840],
        [1280, 5120],
        [5120, 6912],
        [3456, 5120],
    ],
    "meta-llama/Llama-2-70b-hf/TP1": [
        [8192, 10240],
        [8192, 8192],
        [8192, 57344],
        [28672, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP2": [
        [8192, 5120],
        [4096, 8192],
        [8192, 28672],
        [14336, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP4": [
        [8192, 2560],
        [2048, 8192],
        [8192, 14336],
        [7168, 8192],
    ],
}

WEIGHT_SHAPES_MOE = {
    "mistralai/Mixtral-8x7B-Instruct-v0.1": [
        [8, 2, 4096, 28672],
        [8, 2, 14336, 4096],
    ],
    "deepseek-ai/DeepSeek-V2-Lite": [
        [64, 6, 2048, 1408],
    ],
    "ibm-granite/granite-3.0-1b-a400m": [
        [32, 8, 1024, 1024],
    ],
    "ibm-granite/granite-3.0-3b-a800m": [
        [40, 8, 1024, 1536],
    ],
}
720
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
Normal file
@@ -0,0 +1,720 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Comprehensive 3-way SiLU Benchmark Suite

This benchmark compares three SiLU implementations:
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
2. Triton Kernel - Triton-based implementation

The suite generates detailed performance comparisons including:
- Memory bandwidth utilization
- Speedup ratios (baseline vs optimized implementations)
- Performance across different expert configurations and token distributions
"""

from collections.abc import Callable

import matplotlib.pyplot as plt
import numpy as np
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    persistent_masked_m_silu_mul_quant,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used


@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
    # Pointers ------------------------------------------------------------
    input_ptr,  # 16-bit activations (E, T, 2*H)
    y_q_ptr,  # fp8 quantized activations (E, T, H)
    y_s_ptr,  # 16-bit scales (E, T, G)
    counts_ptr,  # int32 num tokens per expert (E)
    # Sizes ---------------------------------------------------------------
    H: tl.constexpr,  # hidden dimension (per output)
    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
    # Strides for input (elements) ---------------------------------------
    stride_i_e,
    stride_i_t,
    stride_i_h,
    # Strides for y_q (elements) -----------------------------------------
    stride_yq_e,
    stride_yq_t,
    stride_yq_h,
    # Strides for y_s (elements) -----------------------------------------
    stride_ys_e,
    stride_ys_t,
    stride_ys_g,
    # Stride for counts (elements)
    stride_counts_e,
    # Numeric params ------------------------------------------------------
    eps: tl.constexpr,
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,
    use_ue8m0: tl.constexpr,
    # Meta ---------------------------------------------------------------
    BLOCK: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    G = H // GROUP_SIZE

    # map program id -> (e, g)
    pid = tl.program_id(0)
    e = pid // G
    g = pid % G

    e = e.to(tl.int64)
    g = g.to(tl.int64)

    # number of valid tokens for this expert
    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

    cols = tl.arange(0, BLOCK).to(tl.int64)
    mask = cols < BLOCK

    base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
    base_gate_offset = base_input_offset + cols * stride_i_h
    base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
    base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
    base_ys_offset = e * stride_ys_e + g * stride_ys_g

    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
        gate = tl.load(
            input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
        ).to(tl.float32)
        up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0)

        gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
        y = gate * up

        y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
        if use_ue8m0:
|
||||
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
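        # exp2(ceil(log2(s))) rounds the scale up to the next power of two
        # (e.g. 0.013 -> 2**-6 = 0.015625), so it stays exactly representable
        # in the 8-bit power-of-two (ue8m0) scale format used by DeepGEMM.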

        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
        tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)


def silu_mul_fp8_quant_deep_gemm_triton(
    y: torch.Tensor,  # (E, T, 2*H)
    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
    num_parallel_tokens,
    group_size: int = 128,
    eps: float = 1e-10,
    expert_offsets: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales

    y has shape (E, T, 2*H). The first half of the last dimension is
    silu-activated, multiplied by the second half, then quantized into FP8.

    Returns `(y_q, y_s)` where
    * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
    * `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
    """
    assert y.ndim == 3, "y must be (E, T, 2*H)"
    E, T, H2 = y.shape
    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
    H = H2 // 2
    G = (H + group_size - 1) // group_size
    assert H % group_size == 0, "H must be divisible by group_size"
    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
        "tokens_per_expert must be shape (E,)"
    )
    tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)

    # allocate outputs
    fp8_dtype = torch.float8_e4m3fn
    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

    # strides (elements)
    stride_i_e, stride_i_t, stride_i_h = y.stride()
    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()

    # desired scale strides (elements): (T*G, 1, T)
    stride_ys_e = T * G
    stride_ys_t = 1
    stride_ys_g = T
    y_s = torch.empty_strided(
        (E, T, G),
        (stride_ys_e, stride_ys_t, stride_ys_g),
        dtype=torch.float32,
        device=y.device,
    )

    stride_cnt_e = tokens_per_expert.stride()[0]

    # Static grid over experts and H-groups.
    # A loop inside the kernel handles the token dim
    grid = (E * G,)

    f_info = torch.finfo(fp8_dtype)
    fp8_max = f_info.max
    fp8_min = f_info.min

    _silu_mul_fp8_quant_deep_gemm[grid](
        y,
        y_q,
        y_s,
        tokens_per_expert,
        H,
        group_size,
        stride_i_e,
        stride_i_t,
        stride_i_h,
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,
        stride_cnt_e,
        eps,
        fp8_min,
        fp8_max,
        is_deep_gemm_e8m0_used(),
        BLOCK=group_size,
        NUM_STAGES=4,
        num_warps=1,
    )

    return y_q, y_s
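
# Example (a minimal sketch, assuming a CUDA device): E=2 experts, T=4 token
# slots, H=256 (group_size=128 divides H), with ragged per-expert counts:
#   y = torch.randn(2, 4, 2 * 256, dtype=torch.bfloat16, device="cuda")
#   counts = torch.tensor([4, 2], dtype=torch.int32, device="cuda")
#   y_q, y_s = silu_mul_fp8_quant_deep_gemm_triton(y, counts, num_parallel_tokens=64)
#   # y_q: (2, 4, 256) float8_e4m3fn, y_s: (2, 4, 2) float32 group scales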


# Token-distribution strategies used to generate benchmark inputs
strategies = ["random_imbalanced", "uniform", "max_t"]


def benchmark(
    kernel: Callable,
    E: int,
    T: int,
    H: int,
    total_tokens: int,
    num_parallel_tokens: int = 64,
    G: int = 128,
    runs: int = 200,
    num_warmups: int = 20,
    gen_strategy: str = "default",
    iterations_per_run: int = 20,
):
    def generate_data(seed_offset=0):
        """Generate input data with given seed offset"""
        current_platform.seed_everything(42 + seed_offset)
        y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()

        if gen_strategy == "random_imbalanced":

            def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
                mean = total_tokens // n_e
                min_max = mean // ratio
                e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
                e[0] = min_max
                r = torch.rand(size=(E - 1,))
                r /= r.sum()
                r *= total_tokens - min_max
                r = r.round().long()
                e[1:] = r.to(device=device)
                return e

            tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
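            # With ratio=0.7, expert 0 receives mean // 0.7 (~1.4x the mean
            # load); the remaining tokens are split randomly across the rest.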
        elif gen_strategy == "uniform":
            r = torch.rand(size=(E,))
            r /= r.sum()
            r *= total_tokens
            r = r.round().long()
            tokens_per_expert = r
        elif gen_strategy == "max_t":
            tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
            tokens_per_expert.fill_(total_tokens // E)
        elif gen_strategy == "first_t":
            tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda")
            tokens_per_expert[0] = min(T, total_tokens)
        else:
            raise ValueError(f"Unknown generation strategy: {gen_strategy}")
        return y, tokens_per_expert

    dataset_count = 4
    # Pre-generate different input matrices for each iteration to avoid cache effects
    data_sets = [generate_data(i) for i in range(dataset_count)]

    # Warmup
    y, tokens_per_expert = data_sets[0]
    for _ in range(num_warmups):
        kernel(
            y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G
        )
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    # Benchmark
    latencies: list[float] = []
    for _ in range(runs):
        torch.cuda.synchronize()

        start_event.record()
        for i in range(iterations_per_run):
            y, tokens_per_expert = data_sets[i % dataset_count]
            kernel(
                y,
                tokens_per_expert,
                num_parallel_tokens=num_parallel_tokens,
                group_size=G,
            )
        end_event.record()
        end_event.synchronize()

        total_time_ms = start_event.elapsed_time(end_event)
        per_iter_time_ms = total_time_ms / iterations_per_run
        latencies.append(per_iter_time_ms)

    # Use median instead of average for better outlier handling
    median_time_ms = np.median(latencies)
    median_time_s = median_time_ms / 1000

    # Calculate actual work done (using first dataset for consistency)
    _, tokens_per_expert = data_sets[0]
    actual_tokens = tokens_per_expert.sum().item()
    actual_elements = actual_tokens * H

    # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
    ops_per_element = 8
    total_ops = actual_elements * ops_per_element
    gflops = total_ops / median_time_s / 1e9

    # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
    input_bytes = actual_tokens * 2 * H * 2  # 2*H bfloat16 inputs
    output_bytes = actual_tokens * H * 1  # H fp8 outputs
    scale_bytes = actual_tokens * (H // G) * 4  # scales in float32
    total_bytes = input_bytes + output_bytes + scale_bytes
    memory_bw = total_bytes / median_time_s / 1e9

    HOPPER_BANDWIDTH_TBPS = 3.35
    return (
        median_time_ms,
        gflops,
        memory_bw,
        (memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100,
    )
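    # Worked example of the bandwidth model above (illustrative, not measured):
    # with H=7168, G=128 and 8192 actual tokens,
    #   input_bytes  = 8192 * 2 * 7168 * 2 = ~234.9 MB
    #   output_bytes = 8192 * 7168         = ~58.7 MB
    #   scale_bytes  = 8192 * 56 * 4       = ~1.8 MB
    # so a 0.1 ms median implies ~2954 GB/s, i.e. ~86% of the 3.35 TB/s peak.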


def create_comparison_plot(
    ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
):
    fig, ax = plt.subplots(1, 1, figsize=(18, 6))

    # Configure x-axis positions
    x = np.arange(len(config_labels))
    width = 0.25

    # Execution Time plot (lower is better)
    ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
    ax.bar(
        x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
    )

    # Add speedup labels over each bar pair
    for i in range(len(x)):
        triton_v2_speedup = ratios[i][1]  # triton/v2
        max_height = max(silu_v2_times[i], triton_times[i])

        # Triton/V2 speedup
        ax.text(
            x[i] + width / 2,
            max_height + max_height * 0.02,
            f"{triton_v2_speedup:.2f}x",
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=8,
        )

    ax.set_xlabel("Configuration")
    ax.set_ylabel("Time (ms)")
    ax.set_title(f"Execution Time (ms) - {strategy_name}\n(Lower is Better)")
    ax.set_xticks(x)
    ax.set_xticklabels(config_labels, rotation=45, ha="right")
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig, ax


def create_combined_plot(all_results):
    num_strategies = len(all_results)
    fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))

    if num_strategies == 1:
        axes = [axes]

    for idx, (
        strategy_name,
        all_ratios,
        all_silu_v2_results,
        all_triton_results,
        config_labels,
        config_x_axis,
    ) in enumerate(all_results):
        ax = axes[idx]

        # Flatten the nested results to get bandwidth percentages for plotting
        silu_v2_bandwidths = []
        triton_bandwidths = []
        flat_ratios = []

        for config_results in all_silu_v2_results:
            for result in config_results:
                silu_v2_bandwidths.append(result[3])  # bandwidth percentage

        for config_results in all_triton_results:
            for result in config_results:
                triton_bandwidths.append(result[3])  # bandwidth percentage

        for config_ratios in all_ratios:
            for ratio in config_ratios:
                flat_ratios.append(ratio)

        # Configure x-axis positions
        x = np.arange(len(config_labels))
        width = 0.25

        # Bandwidth utilization plot (higher is better)
        ax.bar(
            x,
            silu_v2_bandwidths,
            width,
            label="SiLU V2 (CUDA)",
            alpha=0.8,
            color="blue",
        )
        ax.bar(
            x + width,
            triton_bandwidths,
            width,
            label="Triton Kernel",
            alpha=0.8,
            color="green",
        )

        # Add speedup labels over each bar pair
        for i in range(len(x)):
            triton_v2_speedup = flat_ratios[i]  # triton/v2
            max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])

            # Triton/V2 speedup
            ax.text(
                x[i] + width / 2,
                max_height + max_height * 0.02,
                f"{triton_v2_speedup:.2f}x",
                ha="center",
                va="bottom",
                fontweight="bold",
                fontsize=8,
            )

        ax.set_xlabel("Configuration")
        ax.set_ylabel("% Utilization")
        ax.set_title(
            f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
        )
        ax.set_xticks(x)
        ax.set_xticklabels(config_labels, rotation=45, ha="right")
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    filename = "silu_benchmark_combined_3way.png"
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()

    return filename


outer_dim = 7168
configs = [
    # DeepSeekV3 configs
    # (1, 56, 7168),
    (8, 1024, 7168),
    # (32, 56, 7168),
    (32, 1024, 7168),
    (256, 1024, 7168),
]

runs = 100
num_warmups = 20

strategy_descriptions = {
    "uniform": "Uniform Random",
    "random_imbalanced": "Imbalanced Random",
    "max_t": "Even Assignment",
    "first_t": "experts[0] = T, experts[1:] = 0",
}

print(f"GPU: {torch.cuda.get_device_name()}")
print(f"Testing strategies: {', '.join(strategies)}")
print(f"Configurations: {len(configs)} configs")

all_results = []

# Run benchmarks for each strategy
for id, strategy in enumerate(strategies):
    print(f"\n{'=' * 60}")
    print(f"Testing strategy: {strategy_descriptions[strategy]}")
    print(f"{'=' * 60}")

    # Collect benchmark data for both implementations
    config_labels = []
    config_x_axis = []
    all_silu_v2_results = []
    all_triton_results = []
    all_ratios = []

    for E, T, H in configs:
        total_tokens_config = []
        for i in [8, 16, 32, 64, 128, 256, 512]:
            if i <= T:
                total_tokens_config.append(i * E)
        config_x_axis.append(total_tokens_config)

        silu_v2_results = []
        triton_results = []
        ratios = []

        for total_tokens in total_tokens_config:
            config_label = f"E={E},T={T},H={H},TT={total_tokens}"
            config_labels.append(config_label)

            # SiLU V2 (CUDA kernel) results
            time_ms_silu_v2, gflops, gbps, perc = benchmark(
                persistent_masked_m_silu_mul_quant,
                E,
                T,
                H,
                total_tokens,
                runs=runs,
                num_warmups=num_warmups,
                gen_strategy=strategy,
            )
            silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))

            # Triton kernel results
            time_ms_triton, gflops, gbps, perc = benchmark(
                silu_mul_fp8_quant_deep_gemm_triton,
                E,
                T,
                H,
                total_tokens,
                runs=runs,
                num_warmups=num_warmups,
                gen_strategy=strategy,
            )
            triton_results.append((time_ms_triton, gflops, gbps, perc))

            # Speedup ratio: Triton time / V2 time (>1 means V2 is faster)
            triton_v2_ratio = time_ms_triton / time_ms_silu_v2
            ratios.append(triton_v2_ratio)

            print(
                f"Completed: {config_label}:"
                f" V2: {time_ms_silu_v2:.3f}ms,"
                f" Triton: {time_ms_triton:.3f}ms"
            )

        all_silu_v2_results.append(silu_v2_results)
        all_triton_results.append(triton_results)
        all_ratios.append(ratios)

    # Store results for combined plotting
    all_results.append(
        (
            strategy_descriptions[strategy],
            all_ratios,
            all_silu_v2_results,
            all_triton_results,
            config_labels,
            config_x_axis,
        )
    )

    # Print summary table for this strategy
    print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
    print(f"{'Config':<20} {'V2 Time(ms)':<12} {'Triton Time(ms)':<16} {'Triton/V2':<10}")
    print("-" * 90)

    for i, (E, T, H) in enumerate(configs):
        # Use the first total-token result of each config for the summary
        v2_time = all_silu_v2_results[i][0][0]
        triton_time = all_triton_results[i][0][0]
        triton_v2_speedup = triton_time / v2_time
        config_label = f"E={E:3d},T={T:4d},H={H:4d}"
        print(
            f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} "
            f"{triton_v2_speedup:8.2f}x"
        )


def create_total_tokens_plot(all_results):
    num_strategies = len(all_results)
    num_configs = len(configs)

    fig, axs = plt.subplots(
        num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
    )

    # Add main title to the entire figure
    fig.suptitle(
        "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2 and Triton)",
        fontsize=18,
        fontweight="bold",
        y=0.98,
    )

    # Handle single strategy case
    if num_strategies == 1:
        axs = axs.reshape(1, -1)

    # Handle single config case
    if num_configs == 1:
        axs = axs.reshape(-1, 2)

    for strategy_idx, result in enumerate(all_results):
        (
            strategy_name,
            all_ratios,
            all_silu_v2_results,
            all_triton_results,
            config_labels,
            config_x_axis,
        ) = result

        for config_idx in range(num_configs):
            # Speedup plot (left column)
            ax_speedup = axs[strategy_idx, config_idx * 2]
            # Bandwidth plot (right column)
            ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1]

            E, T, H = configs[config_idx]
            ratios = all_ratios[config_idx]
            total_tokens_values = config_x_axis[config_idx]

            # Extract speedup ratios
            triton_v2_ratios = list(ratios)

            # Extract bandwidth percentages for both implementations
            v2_bandwidth_percentages = [
                result[3] for result in all_silu_v2_results[config_idx]
            ]
            triton_bandwidth_percentages = [
                result[3] for result in all_triton_results[config_idx]
            ]

            # Plot speedup ratios vs total tokens (left plot)
            ax_speedup.plot(
                total_tokens_values,
                triton_v2_ratios,
                "go-",
                linewidth=3,
                markersize=8,
                label="Triton/V2 Speedup",
            )
            ax_speedup.set_title(
                f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
                fontsize=12,
                fontweight="bold",
            )
            ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
            ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
            ax_speedup.legend(prop={"weight": "bold"})
            ax_speedup.grid(True, alpha=0.3)

            # Plot bandwidth utilization (right plot)
            ax_bandwidth.plot(
                total_tokens_values,
                v2_bandwidth_percentages,
                "o-",
                linewidth=3,
                markersize=8,
                label="SiLU V2",
                color="blue",
            )
            ax_bandwidth.plot(
                total_tokens_values,
                triton_bandwidth_percentages,
                "o-",
                linewidth=3,
                markersize=8,
                label="Triton",
                color="green",
            )
            ax_bandwidth.set_title(
                f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
                fontsize=12,
                fontweight="bold",
            )
            ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
            ax_bandwidth.set_ylabel(
                "% of Peak Bandwidth", fontweight="bold", fontsize=11
            )
            ax_bandwidth.legend(prop={"weight": "bold"})
            ax_bandwidth.grid(True, alpha=0.3)

            # Format x-axis labels for both plots
            for ax in [ax_speedup, ax_bandwidth]:
                ax.set_xticks(total_tokens_values)
                ax.set_xticklabels(
                    [
                        f"{tt // 1000}K" if tt >= 1000 else str(tt)
                        for tt in total_tokens_values
                    ],
                    fontweight="bold",
                )
                # Make tick labels bold
                for label in ax.get_xticklabels() + ax.get_yticklabels():
                    label.set_fontweight("bold")

            # Add value labels on Triton/V2 speedup points
            for x, y in zip(total_tokens_values, triton_v2_ratios):
                ax_speedup.annotate(
                    f"{y:.2f}x",
                    (x, y),
                    textcoords="offset points",
                    xytext=(0, -15),
                    ha="center",
                    fontsize=9,
                    fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3),
                )

    plt.tight_layout()
    plt.subplots_adjust(top=0.93)  # Make room for main title
    filename = "silu_benchmark_total_tokens_3way.png"
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.show()

    return filename


# Create the comparison plots
combined_plot_filename = create_combined_plot(all_results)
total_tokens_plot_filename = create_total_tokens_plot(all_results)

print(f"\n{'=' * 80}")
print("SiLU Benchmark Suite Complete!")
print(f"Generated combined comparison plot: {combined_plot_filename}")
print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
print("Compared: SiLU V2 (CUDA) and Triton implementations")
print(f"{'=' * 80}")
290
benchmarks/kernels/benchmark_trtllm_decode_attention.py
Normal file
@@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
from datetime import datetime

import flashinfer
import torch

from vllm.utils.math_utils import round_up

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()
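
# Example round-trip (a sketch): dequantize by multiplying back with the
# returned reciprocal scale.
#   x = torch.randn(4, 8, device="cuda")
#   x_fp8, inv_scale = to_float8(x)
#   x_approx = x_fp8.to(torch.float32) * inv_scale  # ~x, up to fp8 rounding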


@torch.no_grad()
def benchmark_decode(
    dtype: torch.dtype,
    quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
    batch_size: int,
    max_seq_len: int,
    num_heads: tuple[int, int] = (64, 8),
    head_size: int = 128,
    kv_layout: str = "HND",
    block_size: int = 16,
    warmup: int = 10,
    trials: int = 20,
):
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / block_size)

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    # Always using 1.0 scale to reflect the real perf in benchmarking
    q_scale = 1.0
    ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
    if q_quant_dtype == FP8_DTYPE:
        query, _ = to_float8(ref_query)
    else:
        query = ref_query

    kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_seq_len

    seq_lens = kv_lens
    max_seq_len = torch.max(seq_lens).item()

    # Always using 1.0 scale to reflect the real perf in benchmarking
    k_scale = v_scale = 1.0
    ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, _ = to_float8(ref_kv_cache)
    else:
        kv_cache = ref_kv_cache

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(batch_size):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=True,
    )
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_size,
        block_size,
        "NONE",
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.Event(enable_timing=True)
        end = torch.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    o_scale = 1.0
    o_sf_scale = None
    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
    if o_quant_dtype == FP4_DTYPE:
        o_sf_scale = 500.0
        output_trtllm = flashinfer.utils.FP4Tensor(
            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
            torch.empty(
                (
                    round_up(query.shape[0], 128),
                    round_up(query.shape[1] * query.shape[2] // 16, 4),
                ),
                dtype=torch.float8_e4m3fn,
            ),
        )
    else:
        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    def baseline_decode():
        return wrapper.run(
            ref_query,
            ref_kv_cache,
            k_scale=k_scale,
            v_scale=v_scale,
            out=output_baseline,
        )

    def trtllm_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens,
            max_seq_len=max_seq_len,
            bmm1_scale=q_scale * k_scale * sm_scale,
            bmm2_scale=v_scale / o_scale,
            o_sf_scale=o_sf_scale,
            out=output_trtllm,
        )

    baseline_mean, baseline_std = time_fn(baseline_decode, warmup=warmup, trials=trials)
    trtllm_mean, trtllm_std = time_fn(trtllm_decode, warmup=warmup, trials=trials)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (None, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]

    for quant_dtype in quant_dtypes:
        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
        q_quant_dtype = q_quant_dtype or dtype
        kv_quant_dtype = kv_quant_dtype or dtype
        o_quant_dtype = o_quant_dtype or dtype

        print(
            f"Running benchmark for q_dtype = {q_quant_dtype}, "
            f"kv_cache_dtype: {kv_quant_dtype}, "
            f"output_dtype: {o_quant_dtype}"
        )
        print(
            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
            "baseline_std\tspeedup_percent"
        )
        for max_seq_len in max_seq_lens:
            for bs in batch_sizes:
                result = benchmark_decode(
                    dtype=dtype,
                    quant_dtypes=quant_dtype,
                    batch_size=bs,
                    max_seq_len=max_seq_len,
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
305
benchmarks/kernels/benchmark_trtllm_prefill_attention.py
Normal file
@@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
from datetime import datetime

import flashinfer
import torch

from vllm.utils.math_utils import round_up

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_prefill(
    dtype: torch.dtype,
    quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
    batch_size: int,
    max_seq_len: int,
    num_heads: tuple[int, int] = (64, 8),
    head_size: int = 128,
    kv_layout: str = "HND",
    block_size: int = 16,
    warmup: int = 10,
    trials: int = 20,
):
    torch.set_default_device("cuda")
    torch.manual_seed(0)

    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    max_q_len = max_kv_len = max_seq_len

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / block_size)

    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")

    q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
    q_lens[-1] = max_q_len
    q_indptr = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32),
            torch.cumsum(q_lens, dim=0, dtype=torch.int32),
        ]
    )
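    # q_indptr is the ragged-layout prefix sum: query rows for request i live
    # in ref_query[q_indptr[i] : q_indptr[i + 1]].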

    # Always using 1.0 scale to reflect the real perf in benchmarking
    q_scale = 1.0
    ref_query = torch.randn(
        torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
    )
    if q_quant_dtype == FP8_DTYPE:
        query, _ = to_float8(ref_query)
    else:
        query = ref_query

    kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_kv_len

    seq_lens = kv_lens + q_lens
    max_seq_len = torch.max(seq_lens).item()

    # Always using 1.0 scale to reflect the real perf in benchmarking
    k_scale = v_scale = 1.0
    ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, _ = to_float8(ref_kv_cache)
    else:
        kv_cache = ref_kv_cache

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(batch_size):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_size,
        block_size,
        causal=True,
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.Event(enable_timing=True)
        end = torch.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    o_scale = 1.0
    o_sf_scale = None
    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
    if o_quant_dtype == FP4_DTYPE:
        o_sf_scale = 500.0
        output_trtllm = flashinfer.utils.FP4Tensor(
            torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
            torch.empty(
                (
                    round_up(query.shape[0], 128),
                    round_up(query.shape[1] * query.shape[2] // 16, 4),
                ),
                dtype=torch.float8_e4m3fn,
            ),
        )
    else:
        output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    def baseline_prefill():
        return wrapper.run(
            ref_query,
            ref_kv_cache,
            k_scale=k_scale,
            v_scale=v_scale,
            out=output_baseline,
        )

    def trtllm_prefill():
        return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens,
            max_q_len=max_q_len,
            max_kv_len=max_seq_len,
            bmm1_scale=q_scale * k_scale * sm_scale,
            bmm2_scale=v_scale / o_scale,
            batch_size=batch_size,
            cum_seq_lens_q=q_indptr,
            cum_seq_lens_kv=kv_indptr,
            o_sf_scale=o_sf_scale,
            out=output_trtllm,
        )

    baseline_mean, baseline_std = time_fn(baseline_prefill, warmup=warmup, trials=trials)
    trtllm_mean, trtllm_std = time_fn(trtllm_prefill, warmup=warmup, trials=trials)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
        f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (FP8_DTYPE, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
        (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
    ]

    for quant_dtype in quant_dtypes:
        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
        q_quant_dtype = q_quant_dtype or dtype
        kv_quant_dtype = kv_quant_dtype or dtype
        o_quant_dtype = o_quant_dtype or dtype

        print(
            f"Running benchmark for q_dtype = {q_quant_dtype}, "
            f"kv_cache_dtype: {kv_quant_dtype}, "
            f"output_dtype: {o_quant_dtype}"
        )
        print(
            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
            "baseline_std\tspeedup_percent"
        )
        for max_seq_len in max_seq_lens:
            for bs in batch_sizes:
                result = benchmark_prefill(
                    dtype=dtype,
                    quant_dtypes=quant_dtype,
                    batch_size=bs,
                    max_seq_len=max_seq_len,
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
415
benchmarks/kernels/benchmark_w8a8_block_fp8.py
Normal file
@@ -0,0 +1,415 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from sglang quantization/tuning_block_wise_kernel.py

import argparse
import json
import multiprocessing as mp
import os
import time
from datetime import datetime
from typing import Any

import torch
from tqdm import tqdm

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _w8a8_triton_block_scaled_mm,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser

mp.set_start_method("spawn", force=True)

assert current_platform.is_cuda(), (
    "Only support tune w8a8 block fp8 kernel on CUDA device."
)

DTYPE_MAP = {
    "float32": torch.float32,
    "float16": torch.float16,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
}


def w8a8_block_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    config: dict[str, Any],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """This function performs matrix multiplication with
    block-wise quantization.

    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
    The output is returned in the specified `output_dtype`.

    Args:
        A: The input tensor, e.g., activation.
        B: The input tensor, e.g., weight.
        As: The per-token-group quantization scale for `A`.
        Bs: The per-block quantization scale for `B`.
        block_size: The block size for per-block quantization.
            It should be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.

    Returns:
        torch.Tensor: The result of matmul.
    """
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]

    assert A.shape[-1] == B.shape[-1]
    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
    M = A.numel() // A.shape[-1]

    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    N, K = B.shape
    assert triton.cdiv(N, block_n) == Bs.shape[0]
    assert triton.cdiv(K, block_k) == Bs.shape[1]

    C_shape = A.shape[:-1] + (N,)
    C = A.new_empty(C_shape, dtype=output_dtype)

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )
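    # The launch grid is 1-D: it flattens the (M-tile, N-tile) space and lets
    # the kernel recover both tile indices from tl.program_id(0).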

    if A.dtype == torch.float8_e4m3fn:
        kernel = _w8a8_triton_block_scaled_mm
    else:
        raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")

    kernel[grid](
        A,
        B,
        C,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        C.stride(-2),
        C.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return C
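
# Example (a minimal sketch, assuming a CUDA device): multiply fp8 tensors
# with 128x128 weight blocks and per-token-group activation scales.
#   M, N, K = 64, 512, 512
#   A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
#   B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
#   As = torch.rand(M, K // 128, dtype=torch.float32, device="cuda")
#   Bs = torch.rand(N // 128, K // 128, dtype=torch.float32, device="cuda")
#   config = get_configs_compute_bound()[0]
#   C = w8a8_block_matmul(A, B, As, Bs, [128, 128], config)  # (M, N) float16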


def get_configs_compute_bound():
    configs = []
    for num_stages in [2, 3, 4, 5]:
        for block_m in [16, 32, 64, 128, 256]:
            for block_k in [64, 128]:
                for block_n in [32, 64, 128, 256]:
                    for num_warps in [4, 8]:
                        for group_size in [1, 16, 32, 64]:
                            configs.append(
                                {
                                    "BLOCK_SIZE_M": block_m,
                                    "BLOCK_SIZE_N": block_n,
                                    "BLOCK_SIZE_K": block_k,
                                    "GROUP_SIZE_M": group_size,
                                    "num_warps": num_warps,
                                    "num_stages": num_stages,
                                }
                            )
    return configs
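
# The full search space above has 4 * 5 * 2 * 4 * 2 * 4 = 1280 candidate
# configs; tune_on_gpu() below filters it to configs where the tuned block_k
# is a multiple of BLOCK_SIZE_K before benchmarking.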


def get_weight_shapes(tp_size):
    # NOTE(HandH1998): The weight shapes only work for DeepSeek-V3.
    # Modify them if you tune for a different model.
    # cannot TP
    total = [
        (512 + 64, 7168),
        (2112, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (12288, 7168),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    weight_shapes = []
    for t in total:
        weight_shapes.append(t)
    for n_t in n_tp:
        new_t = (n_t[0] // tp_size, n_t[1])
        weight_shapes.append(new_t)
    for k_t in k_tp:
        new_t = (k_t[0], k_t[1] // tp_size)
        weight_shapes.append(new_t)
    return weight_shapes
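
# Example: with tp_size=8, (18432 * 2, 7168) in n_tp becomes (4608, 7168) and
# (7168, 18432) in k_tp becomes (7168, 2304); shapes in `total` are unchanged.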


def benchmark_config(
    A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
):
    def run():
        w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)

    torch.cuda.synchronize()
    # JIT compilation & warmup
    for _ in range(5):
        run()
    torch.cuda.synchronize()

    start_event = torch.Event(enable_timing=True)
    end_event = torch.Event(enable_timing=True)

    latencies: list[float] = []
    for _ in range(num_iters):
        torch.cuda.synchronize()
        start_event.record()
        run()
        end_event.record()
        end_event.synchronize()
        latencies.append(start_event.elapsed_time(end_event))
    avg = sum(latencies) / num_iters * 1000  # us
    return avg


def tune(M, N, K, block_size, out_dtype, search_space, input_type):
    factor_for_scale = 1e-2

    if input_type == "fp8":
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        A_fp32 = (
            (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
        )
        A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        B_fp32 = (
            (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
        )
        B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    else:
        raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")

    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
    Bs = (
        torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
        * factor_for_scale
    )

    best_config = None
    best_time = float("inf")
    for config in tqdm(search_space):
        try:
            kernel_time = benchmark_config(
                A,
                B,
                As,
                Bs,
                block_size,
                config,
                out_dtype,
                num_iters=10,
            )
        except triton.runtime.autotuner.OutOfResources:
            # Some configurations may be invalid and fail to compile.
            continue

        if kernel_time < best_time:
            best_time = kernel_time
            best_config = config
    now = datetime.now()
    print(f"[{now.ctime()}] Completed tuning for batch_size={M}")
    assert best_config is not None
    return best_config


def save_configs(
    N,
    K,
    block_n,
    block_k,
    configs,
    save_path,
    input_type="fp8",
) -> None:
    os.makedirs(save_path, exist_ok=True)
    device_name = current_platform.get_device_name().replace(" ", "_")
    json_file_name = (
        f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
        f"block_shape=[{block_n},{block_k}].json"
    )

    config_file_path = os.path.join(save_path, json_file_name)
    print(f"Writing best config to {config_file_path}...")

    with open(config_file_path, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def tune_on_gpu(args_dict):
    """Run tuning on a specific GPU."""
    gpu_id = args_dict["gpu_id"]
    batch_sizes = args_dict["batch_sizes"]
    weight_shapes = args_dict["weight_shapes"]
    args = args_dict["args"]

    torch.cuda.set_device(gpu_id)
    print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")

    block_n = args.block_n
    block_k = args.block_k
    out_dtype = DTYPE_MAP[args.out_dtype]
    save_path = args.save_path
    input_type = args.input_type

    search_space = get_configs_compute_bound()
    search_space = [
        config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
    ]

    start = time.time()
    for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
        N, K = shape[0], shape[1]
        print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
        benchmark_results = [
            tune(
                batch_size,
                N,
                K,
                [block_n, block_k],
                out_dtype,
                search_space,
                input_type,
            )
            for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
        ]
        best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
        save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)

    end = time.time()
    print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")


def distribute_batch_sizes(batch_sizes, num_gpus):
    """Distribute batch sizes across available GPUs."""
    batches_per_gpu = []
    for i in range(num_gpus):
        start_idx = i * len(batch_sizes) // num_gpus
        end_idx = (i + 1) * len(batch_sizes) // num_gpus
        batches_per_gpu.append(batch_sizes[start_idx:end_idx])
    return batches_per_gpu
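
# Example: distribute_batch_sizes([1, 2, 4, 8, 16], 2) -> [[1, 2], [4, 8, 16]]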


def main(args):
    print(args)
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("No GPU available for tuning")
    print(f"Found {num_gpus} GPUs for parallel tuning")

    torch.cuda.init()

    if args.batch_size is None:
        batch_sizes = [
            1,
            2,
            4,
            8,
            16,
            24,
            32,
            48,
            64,
            96,
            128,
            256,
            512,
            1024,
            1536,
            2048,
            3072,
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]
        num_gpus = 1  # If only one batch size, use only one GPU

    weight_shapes = get_weight_shapes(args.tp_size)

    batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)

    process_args = []
    for gpu_id in range(num_gpus):
        process_args.append(
            {
                "gpu_id": gpu_id,
                "batch_sizes": batches_per_gpu[gpu_id],
                "weight_shapes": weight_shapes,  # Each GPU processes all weight shapes
                "args": args,
            }
        )

    ctx = mp.get_context("spawn")
    with ctx.Pool(num_gpus) as pool:
        pool.map(tune_on_gpu, process_args)

    print("Multi-GPU tuning completed")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="""
Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
    python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs
        """,
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument("--tp-size", "-tp", type=int, default=8)
    parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
    parser.add_argument(
        "--out-dtype",
        type=str,
        choices=["float32", "float16", "bfloat16", "half"],
        default="float16",
    )
    parser.add_argument("--block-n", type=int, default=128)
    parser.add_argument("--block-k", type=int, default=128)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--save-path", type=str, default="./")
    args = parser.parse_args()

    main(args)
129
benchmarks/kernels/deepgemm/README.md
Normal file
@@ -0,0 +1,129 @@
# DeepSeek DeepGEMM Kernels Benchmark

This directory includes benchmarks comparing DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing Triton- and CUTLASS-based kernels.

Currently, this covers dense GEMMs only and runs only on Hopper GPUs.

## Setup

You need to install vLLM in your usual fashion, then install DeepGEMM from source in its own directory:

```bash
git clone --recursive https://github.com/deepseek-ai/DeepGEMM
cd DeepGEMM
python setup.py install
uv pip install -e .
```
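
If the build succeeded, `python -c "import deep_gemm"` should complete without error (assuming the package exposes the `deep_gemm` module).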

## Usage

```console
python benchmark_fp8_block_dense_gemm.py
INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda.
===== STARTING FP8 GEMM BENCHMARK =====
PyTorch version: 2.5.1+cu124
CUDA version: 12.4
Triton version: 3.1.0
Using device: NVIDIA H100 80GB HBM3
WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.

===== PERFORMANCE COMPARISON =====

DeepGEMM Implementation:
+------+-------+-------+-----------+--------+--------+
| m    | n     | k     | Time (μs) | TFLOPS | GB/s   |
+------+-------+-------+-----------+--------+--------+
|    8 |  4096 |  7168 | 102.9     | 4.6    | 286.4  |
|    8 |  7168 | 18432 | 70.8      | 29.8   | 1868.8 |
|    8 | 18432 |  7168 | 69.3      | 30.5   | 1911.8 |
|   64 |  4096 |  7168 | 69.1      | 54.4   | 439.0  |
|   64 |  7168 | 18432 | 69.4      | 243.6  | 1933.6 |
|   64 | 18432 |  7168 | 70.4      | 240.3  | 1917.2 |
|   64 | 24576 |  1536 | 70.1      | 68.9   | 584.6  |
|   64 | 32768 |   512 | 68.4      | 31.4   | 307.1  |
|   64 |  7168 | 16384 | 69.5      | 216.3  | 1718.5 |
|  128 |  4096 |  7168 | 141.1     | 53.3   | 222.1  |
|  128 |  7168 | 18432 | 71.9      | 470.5  | 1896.1 |
|  128 | 18432 |  7168 | 69.3      | 488.2  | 1988.2 |
| 1024 |  4096 |  7168 | 89.7      | 670.1  | 502.5  |
| 1024 | 18432 |  7168 | 279.0     | 969.8  | 635.2  |
| 2048 |  4096 |  7168 | 175.1     | 687.0  | 347.4  |
| 4096 |  4096 |  7168 | 335.4     | 717.0  | 275.1  |
+------+-------+-------+-----------+--------+--------+

vLLM Triton Implementation:
+------+-------+-------+-----------+--------+--------+--------------+
| m    | n     | k     | Time (μs) | TFLOPS | GB/s   | vs DeepGEMM  |
+------+-------+-------+-----------+--------+--------+--------------+
|    8 |  4096 |  7168 | 74.0      | 6.3    | 398.2  | 1.39x faster |
|    8 |  7168 | 18432 | 89.6      | 23.6   | 1478.1 | 0.79x slower |
|    8 | 18432 |  7168 | 113.2     | 18.7   | 1170.4 | 0.61x slower |
|   64 |  4096 |  7168 | 79.4      | 47.3   | 382.2  | 0.87x slower |
|   64 |  7168 | 18432 | 98.5      | 171.7  | 1363.0 | 0.70x slower |
|   64 | 18432 |  7168 | 119.5     | 141.5  | 1129.4 | 0.59x slower |
|   64 | 24576 |  1536 | 37.6      | 128.4  | 1089.7 | 1.86x faster |
|   64 | 32768 |   512 | 38.7      | 55.5   | 542.6  | 1.77x faster |
|   64 |  7168 | 16384 | 86.1      | 174.5  | 1386.4 | 0.81x slower |
|  128 |  4096 |  7168 | 90.7      | 82.9   | 345.4  | 1.56x faster |
|  128 |  7168 | 18432 | 144.0     | 234.9  | 946.9  | 0.50x slower |
|  128 | 18432 |  7168 | 229.5     | 147.4  | 600.1  | 0.30x slower |
| 1024 |  4096 |  7168 | 242.3     | 248.2  | 186.1  | 0.37x slower |
| 1024 | 18432 |  7168 | 897.8     | 301.4  | 197.4  | 0.31x slower |
| 2048 |  4096 |  7168 | 463.0     | 259.7  | 131.4  | 0.38x slower |
| 4096 |  4096 |  7168 | 901.8     | 266.7  | 102.3  | 0.37x slower |
+------+-------+-------+-----------+--------+--------+--------------+

vLLM CUTLASS Implementation:
+------+-------+-------+-----------+--------+--------+--------------+--------------+
| m    | n     | k     | Time (μs) | TFLOPS | GB/s   | vs DeepGEMM  | vs Triton    |
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|    8 |  4096 |  7168 | 34.6      | 13.6   | 852.3  | 2.98x faster | 2.14x faster |
|    8 |  7168 | 18432 | 78.9      | 26.8   | 1677.3 | 0.90x slower | 1.13x faster |
|    8 | 18432 |  7168 | 81.2      | 26.0   | 1631.1 | 0.85x slower | 1.39x faster |
|   64 |  4096 |  7168 | 36.9      | 101.9  | 822.9  | 1.87x faster | 2.15x faster |
|   64 |  7168 | 18432 | 87.4      | 193.4  | 1535.2 | 0.79x slower | 1.13x faster |
|   64 | 18432 |  7168 | 85.0      | 199.0  | 1587.6 | 0.83x slower | 1.41x faster |
|   64 | 24576 |  1536 | 28.0      | 172.8  | 1465.8 | 2.51x faster | 1.35x faster |
|   64 | 32768 |   512 | 28.8      | 74.5   | 728.5  | 2.37x faster | 1.34x faster |
|   64 |  7168 | 16384 | 77.9      | 193.0  | 1532.8 | 0.89x slower | 1.11x faster |
|  128 |  4096 |  7168 | 39.1      | 192.4  | 802.0  | 3.61x faster | 2.32x faster |
|  128 |  7168 | 18432 | 93.7      | 360.8  | 1454.2 | 0.77x slower | 1.54x faster |
|  128 | 18432 |  7168 | 85.7      | 394.8  | 1608.0 | 0.81x slower | 2.68x faster |
| 1024 |  4096 |  7168 | 99.7      | 603.1  | 452.2  | 0.90x slower | 2.43x faster |
| 1024 | 18432 |  7168 | 331.3     | 816.7  | 534.9  | 0.84x slower | 2.71x faster |
| 2048 |  4096 |  7168 | 198.3     | 606.6  | 306.7  | 0.88x slower | 2.34x faster |
| 4096 |  4096 |  7168 | 392.2     | 613.2  | 235.3  | 0.86x slower | 2.30x faster |
+------+-------+-------+-----------+--------+--------+--------------+--------------+

===== AVERAGE PERFORMANCE =====
+----------------+------------+----------+---------------+
| Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) |
+----------------+------------+----------+---------------+
| DeepGEMM       | 310.98     | 1052.10  | 0.11          |
| vLLM Triton    | 144.30     | 715.60   | 0.23          |
| vLLM CUTLASS   | 286.78     | 1076.67  | 0.11          |
+----------------+------------+----------+---------------+

===== AVERAGE SPEEDUPS =====
+-----------------------------+--------------+
| Comparison                  | Speedup      |
+-----------------------------+--------------+
| DeepGEMM vs vLLM Triton     | 1.71x faster |
| DeepGEMM vs vLLM CUTLASS    | 0.94x slower |
| vLLM CUTLASS vs vLLM Triton | 1.84x faster |
+-----------------------------+--------------+

===== ACCURACY COMPARISON =====
+----------------+-----------------------+
| Implementation | Avg Diff vs Reference |
+----------------+-----------------------+
| DeepGEMM       | 0.000684              |
| vLLM Triton    | 0.000684              |
| vLLM CUTLASS   | 0.000684              |
+----------------+-----------------------+
```
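
The speedup columns are plain time ratios. As a rough sketch (`speedup_str` is a hypothetical helper; the script's own `format_speedup` takes the precomputed ratio):

```python
def speedup_str(baseline_us: float, other_us: float) -> str:
    # Ratio > 1.0 means the other implementation beat the baseline.
    ratio = baseline_us / other_us
    return f"{ratio:.2f}x {'faster' if ratio > 1.0 else 'slower'}"

# m=8, n=4096, k=7168: DeepGEMM at 102.9 us vs Triton at 74.0 us
print(speedup_str(102.9, 74.0))  # -> "1.39x faster"
```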
435
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
Normal file
@@ -0,0 +1,435 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import time

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
    w8a8_triton_block_scaled_mm,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import (
    calc_diff,
    fp8_gemm_nt,
    get_col_major_tma_aligned_tensor,
    per_block_cast_to_fp8,
)


def benchmark_shape(
    m: int,
    n: int,
    k: int,
    warmup: int = 100,
    repeat: int = 10000,
    verbose: bool = False,
) -> dict:
    """Benchmark all implementations for a specific (m, n, k) shape."""
    if verbose:
        print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")

    # Create test tensors
    A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

    # Reference result in BF16
    torch.cuda.synchronize()
    C_ref = A @ B.t()

    # Pre-quantize B for all implementations
    # (weights can be pre-quantized offline)
    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)

    # Block size configuration
    block_size = [128, 128]

    # Pre-quantize A for all implementations
    A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
    A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
    C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
    A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
    A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
        A, block_size[1], column_major_scales=True
    )
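    # Scale-layout note: DeepGEMM consumes the activation scales column-major
    # and TMA-aligned (hence get_col_major_tma_aligned_tensor above), while the
    # CUTLASS path asks for column-major scales directly via column_major_scales=True.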

    # === DeepGEMM Implementation ===
    def deepgemm_gemm():
        fp8_gemm_nt(
            (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
        )
        return C_deepgemm

    # === vLLM Triton Implementation ===
    def vllm_triton_gemm():
        return w8a8_triton_block_scaled_mm(
            A_vllm,
            B_vllm,
            A_scale_vllm,
            B_scale_vllm,
            block_size,
            output_dtype=torch.bfloat16,
        )

    # === vLLM CUTLASS Implementation ===
    def vllm_cutlass_gemm():
        return ops.cutlass_scaled_mm(
            A_vllm_cutlass,
            B_vllm.T,
            scale_a=A_scale_vllm_cutlass,
            scale_b=B_scale_vllm.T,
            out_dtype=torch.bfloat16,
        )

    # Run correctness check first
    if verbose:
        print("Running correctness check...")
    C_deepgemm = deepgemm_gemm()
    C_vllm_triton = vllm_triton_gemm()
    C_vllm_cutlass = vllm_cutlass_gemm()

    deepgemm_diff = calc_diff(C_deepgemm, C_ref)
    vllm_triton_diff = calc_diff(C_vllm_triton, C_ref)
    vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref)

    if verbose:
        print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
        print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
        print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
        print(
            "vLLM Triton vs DeepGEMM difference: "
            f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
        )
        print(
            "vLLM CUTLASS vs DeepGEMM difference: "
            f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
        )

    # Benchmark implementations
    implementations = {
        "DeepGEMM": deepgemm_gemm,
        "vLLM Triton": vllm_triton_gemm,
        "vLLM CUTLASS": vllm_cutlass_gemm,
    }

    benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}

    for name, func in implementations.items():
        # Warmup
        for _ in range(warmup):
            func()
        torch.cuda.synchronize()

        # Timing loop
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(repeat):
            func()
        torch.cuda.synchronize()
        end = time.time()

        # Calculate timing and TFLOPS
        avg_time_ms = (end - start) / repeat * 1000
        avg_time_us = avg_time_ms * 1000
        tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12
        gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3)
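        # The FLOP count is the usual 2*m*n*k (one multiply plus one add per
        # output contribution); the traffic estimate counts A (m*k fp8 bytes),
        # B (k*n fp8 bytes) and C (m*n bf16, two bytes each) once, ignoring scales.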

        benchmark_results["implementations"][name] = {
            "time_ms": avg_time_ms,
            "time_us": avg_time_us,
            "tflops": tflops,
            "gb_s": gb_s,
            "diff": {
                "DeepGEMM": 0.0
                if name == "DeepGEMM"
                else calc_diff(func(), C_deepgemm),
                "Reference": deepgemm_diff
                if name == "DeepGEMM"
                else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
            },
        }

        if verbose:
            print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")

    # Calculate speedups
    baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
    for name, data in benchmark_results["implementations"].items():
        if name != "DeepGEMM":
            speedup = baseline / data["time_ms"]
            benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
            if verbose:
                print(
                    f"DeepGEMM is {1 / speedup:.2f}x "
                    f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
                )

    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
    cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
    benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
        cutlass_vs_triton
    )
    if verbose:
        print(
            f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
            f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
        )

    return benchmark_results


def format_table_row(values, widths):
    """Format a row with specified column widths."""
    return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"


def print_table(headers, rows, title=None):
    """Print a table with headers and rows."""
    if title:
        print(f"\n{title}")

    # Calculate column widths based on headers and data
    widths = [
        max(len(str(h)), max(len(str(row[i])) for row in rows))
        for i, h in enumerate(headers)
    ]

    # Create separator line
    separator = "+-" + "-+-".join("-" * w for w in widths) + "-+"

    # Print table
    print(separator)
    print(format_table_row(headers, widths))
    print(separator)
    for row in rows:
        print(format_table_row(row, widths))
    print(separator)


def format_speedup(value):
    """Format a speedup value with an indicator of whether it is faster or slower."""
    return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}"


def run_benchmarks(verbose: bool = False):
    """Run benchmarks for a set of common shapes."""
    print("===== STARTING FP8 GEMM BENCHMARK =====")

    # Make sure we're using the GPU
    if not torch.cuda.is_available():
        print("CUDA not available! Tests require GPU.")
        return

    # Print system information
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Triton version: {triton.__version__}")
    print(f"Using device: {torch.cuda.get_device_name()}")

    # Enable TF32 for better performance
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Set seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Define benchmark shapes (m, n, k)
    shapes = [
        (8, 4096, 7168),
        (8, 7168, 18432),
        (8, 18432, 7168),
        (64, 4096, 7168),
        (64, 7168, 18432),
        (64, 18432, 7168),
        (64, 24576, 1536),
        (64, 32768, 512),
        (64, 7168, 16384),
        (128, 4096, 7168),
        (128, 7168, 18432),
        (128, 18432, 7168),
        (1024, 4096, 7168),
        (1024, 18432, 7168),
        (2048, 4096, 7168),
        (4096, 4096, 7168),
    ]
    shapes = [
        # (64, 2112, 7168),
        (64, 24576, 1536),
        (64, 32768, 512),
        (64, 7168, 16384),
        (64, 4096, 7168),
        (64, 7168, 2048),
        # (128, 2112, 7168),
        (128, 24576, 1536),
        (128, 32768, 512),
        (128, 7168, 16384),
        (128, 4096, 7168),
        (128, 7168, 2048),
        # (4096, 2112, 7168),
        (4096, 24576, 1536),
        (4096, 32768, 512),
        (4096, 7168, 16384),
        (4096, 4096, 7168),
        (4096, 7168, 2048),
    ]
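    # NOTE: this second assignment overrides the DeepSeek-V3-style list above;
    # only the shapes below are actually benchmarked.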

    all_results = []
    for m, n, k in shapes:
        result = benchmark_shape(m, n, k, verbose=verbose)
        all_results.append(result)

    # Print results in a nicely formatted table
    print("\n===== PERFORMANCE COMPARISON =====")

    # Print DeepGEMM table
    deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"]
    deepgemm_rows = []
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["DeepGEMM"]
        deepgemm_rows.append(
            [
                shape["m"],
                shape["n"],
                shape["k"],
                f"{impl_data['time_us']:.1f}",
                f"{impl_data['tflops']:.1f}",
                f"{impl_data['gb_s']:.1f}",
            ]
        )

    print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")

    # Print vLLM Triton table
    triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
    triton_rows = []
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["vLLM Triton"]
        speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
        triton_rows.append(
            [
                shape["m"],
                shape["n"],
                shape["k"],
                f"{impl_data['time_us']:.1f}",
                f"{impl_data['tflops']:.1f}",
                f"{impl_data['gb_s']:.1f}",
                format_speedup(speedup),
            ]
        )

    print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")

    # Print vLLM CUTLASS table
    cutlass_headers = [
        "m",
        "n",
        "k",
        "Time (μs)",
        "TFLOPS",
        "GB/s",
        "vs DeepGEMM",
        "vs Triton",
    ]
    cutlass_rows = []
    for result in all_results:
        shape = result["shape"]
        impl_data = result["implementations"]["vLLM CUTLASS"]
        vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
        vs_triton = impl_data.get("speedup_vs_triton", 1.0)
        cutlass_rows.append(
            [
                shape["m"],
                shape["n"],
                shape["k"],
                f"{impl_data['time_us']:.1f}",
                f"{impl_data['tflops']:.1f}",
                f"{impl_data['gb_s']:.1f}",
                format_speedup(vs_deepgemm),
                format_speedup(vs_triton),
            ]
        )

    print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")

    # Calculate and print averages
    print("\n===== AVERAGE PERFORMANCE =====")

    implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
    avg_metrics = {
        impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
    }

    for result in all_results:
        for impl in implementations:
            impl_data = result["implementations"][impl]
            avg_metrics[impl]["tflops"] += impl_data["tflops"]
            avg_metrics[impl]["gb_s"] += impl_data["gb_s"]
            avg_metrics[impl]["time_ms"] += impl_data["time_ms"]

    num_shapes = len(all_results)
    avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"]
    avg_rows = []

    for impl in implementations:
        avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
        avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
        avg_time = avg_metrics[impl]["time_ms"] / num_shapes
        avg_rows.append(
            [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
        )

    print_table(avg_headers, avg_rows)

    # Calculate average speedups
    avg_speedups = {
        "DeepGEMM vs vLLM Triton": 0,
        "DeepGEMM vs vLLM CUTLASS": 0,
        "vLLM CUTLASS vs vLLM Triton": 0,
    }

    for result in all_results:
        deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
        vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]

        avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
        avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
        avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
            vllm_triton_time / vllm_cutlass_time
        )

    print("\n===== AVERAGE SPEEDUPS =====")
    speedup_headers = ["Comparison", "Speedup"]
    speedup_rows = []
    for comparison, total in avg_speedups.items():
        avg_speedup = total / num_shapes
        status = "faster" if avg_speedup > 1 else "slower"
        speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"])

    print_table(speedup_headers, speedup_rows)

    # Average accuracy comparison
    print("\n===== ACCURACY COMPARISON =====")
    avg_diff = {impl: 0 for impl in implementations}

    for result in all_results:
        for impl in implementations:
            avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]

    diff_headers = ["Implementation", "Avg Diff vs Reference"]
    diff_rows = []
    for impl in implementations:
        diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"])

    print_table(diff_headers, diff_rows)


if __name__ == "__main__":
    run_benchmarks(verbose=False)
64
benchmarks/kernels/graph_machete_bench.py
Normal file
@@ -0,0 +1,64 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
import pickle
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
import regex as re
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement

from vllm.utils.argparse_utils import FlexibleArgumentParser

if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Plot the results of a Machete kernel benchmark sweep "
        "from a pickled list of measurements."
    )
    parser.add_argument("filename", type=str)

    args = parser.parse_args()

    with open(args.filename, "rb") as f:
        data = pickle.load(f)
        raw_results: list[TMeasurement] = data["results"]

    results = defaultdict(list)
    for v in raw_results:
        result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
        if result is not None:
            KN = result.group(1)
        else:
            raise Exception("MKN not found")
        result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
        if result is not None:
            M = result.group(1)
        else:
            raise Exception("M not found")
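        # Each measurement's sub_label encodes the GEMM shape as "MKN=(MxKxN)":
        # the two regexes above pull out the shared KxN pair (one subplot per
        # shape) and the batch size M (the x-axis).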

        kernel = v.task_spec.description
        results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})

    rows = int(math.ceil(len(results) / 2))
    fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
    axs = axs.flatten()
    for axs_idx, (shape, data) in enumerate(results.items()):
        plt.sca(axs[axs_idx])
        df = pd.DataFrame(data)
        sns.lineplot(
            data=df,
            x="batch_size",
            y="median",
            hue="kernel",
            style="kernel",
            markers=True,
            dashes=False,
            palette="Dark2",
        )
        plt.title(f"Shape: {shape}")
        plt.ylabel("time (median, s)")
    plt.tight_layout()
    plt.savefig("graph_machete_bench.pdf")
1
benchmarks/kernels/requirements.txt
Normal file
@@ -0,0 +1 @@
pandas
214
benchmarks/kernels/utils.py
Normal file
@@ -0,0 +1,214 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from collections.abc import Callable, Iterable
from typing import Any

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement


@dataclasses.dataclass
class CudaGraphBenchParams:
    num_ops_in_cuda_graph: int


@dataclasses.dataclass
class ArgPool:
    """
    When an argument of the benchmarked function is annotated with this type,
    the benchmarking class (Bench) collapses the argument by picking a single
    value from the given list of values on each function invocation. Every
    invocation during a benchmarking run draws a different value from the list.
    """

    values: Iterable[Any]

    def __getitem__(self, index):
        return self.values[index]
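
# Example (illustrative): wrapping an argument in ArgPool makes Bench rotate
# through the pooled values, e.g.
#   Bench(None, "gemm", "M=16", "triton", fn, ArgPool([a0, a1]), k=3)
# times fn(a0, k=3), fn(a1, k=3), fn(a0, k=3), ... on successive invocations.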


class Bench:
    class ArgsIterator:
        def __init__(self, args_list, kwargs_list):
            assert len(args_list) == len(kwargs_list)
            self.args_list = args_list
            self.kwargs_list = kwargs_list
            self.n = len(self.args_list)
            self.idx = 0

        def __next__(self):
            while True:
                yield (self.args_list[self.idx], self.kwargs_list[self.idx])
                self.idx += 1
                self.idx = self.idx % self.n

        def reset(self):
            self.idx = 0

        @property
        def n_args(self):
            return self.n

    def __init__(
        self,
        cuda_graph_params: CudaGraphBenchParams | None,
        label: str,
        sub_label: str,
        description: str,
        fn: Callable,
        *args,
        **kwargs,
    ):
        self.cuda_graph_params = cuda_graph_params
        self.use_cuda_graph = self.cuda_graph_params is not None
        self.label = label
        self.sub_label = sub_label
        self.description = description
        self.fn = fn

        # Process args
        self._args = args
        self._kwargs = kwargs
        self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs)
        self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list)

        # Cudagraph runner
        self.g = None
        if self.use_cuda_graph:
            self.g = self.get_cuda_graph_runner()

        # benchmark run params
        self.min_run_time = 1

    def collapse_argpool(self, *args, **kwargs):
        argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [
            arg for arg in kwargs.values() if isinstance(arg, ArgPool)
        ]
        if len(argpool_args) == 0:
            return [args], [kwargs]

        # Make sure all argpools are of the same size
        argpool_size = len(argpool_args[0].values)
        assert all([argpool_size == len(arg.values) for arg in argpool_args])

        # create copies of the args
        args_list = []
        kwargs_list = []
        for _ in range(argpool_size):
            args_list.append(args)
            kwargs_list.append(kwargs.copy())

        for i in range(argpool_size):
            # collapse args; just pick the ith value
            args_list[i] = tuple(
                [arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]]
            )

            # collapse kwargs
            kwargs_i = kwargs_list[i]
            arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)]
            for k in arg_pool_keys:
                # again just pick the ith value
                kwargs_i[k] = kwargs_i[k][i]
            kwargs_list[i] = kwargs_i

        return args_list, kwargs_list

    def get_cuda_graph_runner(self):
        assert self.use_cuda_graph
        assert self.args_iterator is not None

        num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph

        # warmup
        args_it = self.args_iterator.__next__()
        for _ in range(2):
            args, kwargs = next(args_it)
            self.fn(*args, **kwargs)

        self.args_iterator.reset()
        args_it = self.args_iterator.__next__()
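        # Capture on an explicit side stream, following the standard
        # torch.cuda.CUDAGraph capture pattern; the graph records
        # num_ops_in_cuda_graph invocations and replays them back-to-back.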
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                for _ in range(num_graph_ops):
                    args, kwargs = next(args_it)
                    self.fn(*args, **kwargs)
        return g

    def run_cudagraph(self) -> TMeasurement:
        assert self.use_cuda_graph
        globals = {"g": self.g}

        return TBenchmark.Timer(
            stmt="g.replay()",
            globals=globals,
            label=(
                f"{self.label}"
                f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops"
            ),
            sub_label=self.sub_label,
            description=self.description,
        ).blocked_autorange(min_run_time=self.min_run_time)

    def run_eager(self) -> TMeasurement:
        setup = None
        stmt = None
        globals = None

        has_arg_pool = self.args_iterator.n_args > 1
        if has_arg_pool:
            setup = """
                args_iterator.reset()
                args_it = args_iterator.__next__()
            """
            stmt = """
                args, kwargs = next(args_it)
                fn(*args, **kwargs)
            """
            globals = {"fn": self.fn, "args_iterator": self.args_iterator}
        else:
            # no arg pool. Just use the args and kwargs directly
            self.args_iterator.reset()
            args_it = self.args_iterator.__next__()
            args, kwargs = next(args_it)

            setup = ""
            stmt = """
                fn(*args, **kwargs)
            """
            globals = {"fn": self.fn, "args": args, "kwargs": kwargs}

        return TBenchmark.Timer(
            stmt=stmt,
            setup=setup,
            globals=globals,
            label=self.label,
            sub_label=self.sub_label,
            description=self.description,
        ).blocked_autorange(min_run_time=self.min_run_time)

    def run(self) -> TMeasurement:
        timer = None
        if self.use_cuda_graph:  # noqa SIM108
            timer = self.run_cudagraph()
        else:
            timer = self.run_eager()
        if not timer.meets_confidence() or timer.has_warnings:
            print("Doesn't meet confidence - re-running bench ...")
            return self.run()
        return timer

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type:
            print(f"exc type {exc_type}")
            print(f"exc value {exc_value}")
            print(f"exc traceback {traceback}")
104
benchmarks/kernels/weight_shapes.py
Normal file
@@ -0,0 +1,104 @@

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Weight Shapes are in the format
#  ([K, N], TP_SPLIT_DIM)
# Example:
#  A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#   - TP1 : K = 14336, N = 4096
#   - TP2 : K = 7168, N = 4096
#  A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#   - TP1 : K = 4096, N = 6144
#   - TP4 : K = 4096, N = 1536
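

# Illustrative helper (hypothetical; not used elsewhere in vLLM): derive the
# per-rank GEMM shape from an entry below for a given tensor-parallel size.
def _shard_shape(kn: list[int], tp_split_dim: int, tp_size: int) -> tuple[int, int]:
    k, n = kn
    if tp_split_dim == 0:
        k //= tp_size  # weight is split along K
    else:
        n //= tp_size  # weight is split along N
    return k, n


# _shard_shape([14336, 4096], 0, 2) -> (7168, 4096)
# _shard_shape([4096, 6144], 1, 4) -> (4096, 1536)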

# TP1 shapes
WEIGHT_SHAPES = {
    "mistralai/Mistral-7B-v0.1": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-2-7b-hf": [
        ([4096, 12288], 1),
        ([4096, 4096], 0),
        ([4096, 22016], 1),
        ([11008, 4096], 0),
    ],
    "meta-llama/Llama-3-8b": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-2-13b-hf": [
        ([5120, 15360], 1),
        ([5120, 5120], 0),
        ([5120, 27648], 1),
        ([13824, 5120], 0),
    ],
    "meta-llama/Llama-2-70b-hf": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "meta-llama/Llama-3.1-405b-hf": [
        ([16384, 18432], 1),
        ([16384, 16384], 0),
        ([16384, 106496], 1),
        ([53248, 16384], 0),
    ],
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
    "CohereLabs/c4ai-command-a-03-2025": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 73728], 1),
        ([36864, 12288], 0),
    ],
}