forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
This commit is contained in:
846
vllm/perf/benchmark_moe.py
Normal file
846
vllm/perf/benchmark_moe.py
Normal file
@@ -0,0 +1,846 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from itertools import product
|
||||
from typing import Any, TypedDict, Optional
|
||||
|
||||
import ray
|
||||
import torch
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
# 移除全局的 current_platform 导入,改为在需要时局部导入
|
||||
# FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
BLOCK_SIZE_M: int
|
||||
BLOCK_SIZE_N: int
|
||||
BLOCK_SIZE_K: int
|
||||
GROUP_SIZE_M: int
|
||||
num_warps: int
|
||||
num_stages: int
|
||||
num_ldmatrixes: Optional[int]
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
config: BenchmarkConfig,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
block_quant_shape: list[int] = None,
|
||||
use_deep_gemm: bool = False,
|
||||
nn_moe: Optional[bool] = False
|
||||
) -> float:
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
|
||||
if use_int8_w8a16:
|
||||
if not nn_moe:
|
||||
w1 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
device=device,
|
||||
)
|
||||
w2 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
w1 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
device=device,
|
||||
)
|
||||
w2 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
(
|
||||
num_experts,
|
||||
shard_intermediate_size // 2,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.int8,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
if not nn_moe:
|
||||
w1 = torch.randn(
|
||||
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device
|
||||
)
|
||||
else:
|
||||
w1 = torch.randn(
|
||||
num_experts, hidden_size, shard_intermediate_size, dtype=init_dtype, device=device
|
||||
)
|
||||
w2 = torch.randn(
|
||||
num_experts, shard_intermediate_size // 2, hidden_size, dtype=init_dtype, device=device
|
||||
)
|
||||
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device)
|
||||
|
||||
w1_scale = None
|
||||
w2_scale = None
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
if use_int8_w8a16:
|
||||
w1_scale = torch.randn(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device)
|
||||
if use_fp8_w8a8:
|
||||
if block_quant_shape:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
E = num_experts
|
||||
N = shard_intermediate_size // 2
|
||||
K = hidden_size
|
||||
factor_for_scale = 1e-2
|
||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
w1_scale = (
|
||||
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device)
|
||||
* factor_for_scale
|
||||
)
|
||||
w2_scale = (
|
||||
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device)
|
||||
* factor_for_scale
|
||||
)
|
||||
else:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
|
||||
|
||||
a1_scale = torch.randn(1, dtype=torch.float32, device=device)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32, device=device)
|
||||
|
||||
# 获取 FP8_DTYPE
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
w1 = w1.to(FP8_DTYPE)
|
||||
w2 = w2.to(FP8_DTYPE)
|
||||
|
||||
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device)
|
||||
|
||||
def prepare(i: int):
|
||||
input_gating.copy_(gating_output[i])
|
||||
|
||||
def run():
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
|
||||
with override_config(config):
|
||||
if use_deep_gemm:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
x, input_gating, topk, False
|
||||
)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
allow_deep_gemm=True,
|
||||
use_nn_moe=nn_moe,
|
||||
)
|
||||
else:
|
||||
fused_moe(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
use_nn_moe=nn_moe,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture 10 invocations with CUDA graph
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(10):
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Warmup
|
||||
for _ in range(5):
|
||||
graph.replay()
|
||||
# run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
prepare(i)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event.record()
|
||||
graph.replay()
|
||||
# run()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
graph.reset()
|
||||
return avg
|
||||
|
||||
|
||||
def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
|
||||
block_m_range = [16, 32, 64, 128, 256]
|
||||
block_n_range = [32, 64, 128, 256]
|
||||
block_k_range = [32, 64, 128, 256]
|
||||
if not use_fp16:
|
||||
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
|
||||
num_warps_range = [2, 4, 8]
|
||||
group_m_range = [1, 16, 32, 64]
|
||||
num_stage_range = [2, 3, 4, 5]
|
||||
# waves_per_eu_range = [0]
|
||||
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
|
||||
# kpack_range = [1, 2] if use_fp16 else []
|
||||
|
||||
param_ranges = {
|
||||
"BLOCK_SIZE_M": block_m_range,
|
||||
"BLOCK_SIZE_N": block_n_range,
|
||||
"BLOCK_SIZE_K": block_k_range,
|
||||
"GROUP_SIZE_M": group_m_range,
|
||||
"num_warps": num_warps_range,
|
||||
"num_stages": num_stage_range,
|
||||
# "waves_per_eu": waves_per_eu_range,
|
||||
}
|
||||
if nn_moe:
|
||||
param_ranges["num_ldmatrixes"] = [1]
|
||||
|
||||
# DCU currently does not support the following parameters
|
||||
# if use_fp16:
|
||||
# param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
|
||||
# param_ranges["kpack"] = kpack_range
|
||||
|
||||
return param_ranges
|
||||
|
||||
|
||||
def get_configs_compute_bound(use_fp16, block_quant_shape, nn_moe: Optional[bool] = False) -> list[dict[str, int]]:
|
||||
configs: list[BenchmarkConfig] = []
|
||||
|
||||
# 局部导入 current_platform
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_rocm():
|
||||
param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
|
||||
else:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
block_m_range = [16, 32, 64, 128, 256]
|
||||
block_n_range = [32, 64, 128, 256]
|
||||
block_k_range = [64, 128, 256]
|
||||
num_warps_range = [4, 8]
|
||||
group_m_range = [1, 16, 32, 64]
|
||||
num_stage_range = [2, 3, 4, 5]
|
||||
|
||||
param_ranges = {
|
||||
"BLOCK_SIZE_M": block_m_range,
|
||||
"BLOCK_SIZE_N": block_n_range,
|
||||
"BLOCK_SIZE_K": block_k_range,
|
||||
"GROUP_SIZE_M": group_m_range,
|
||||
"num_warps": num_warps_range,
|
||||
"num_stages": num_stage_range,
|
||||
}
|
||||
|
||||
keys, values = zip(*param_ranges.items())
|
||||
for config_values in product(*values):
|
||||
config = dict(zip(keys, config_values))
|
||||
configs.append(config)
|
||||
|
||||
# Remove configs that are not compatible with fp8 block quantization
|
||||
# BLOCK_SIZE_K must be a multiple of block_k
|
||||
# BLOCK_SIZE_N must be a multiple of block_n
|
||||
if block_quant_shape is not None and not use_fp16:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
for config in configs[:]:
|
||||
if (
|
||||
config["BLOCK_SIZE_K"] % block_k != 0
|
||||
or config["BLOCK_SIZE_N"] % block_n != 0
|
||||
):
|
||||
configs.remove(config)
|
||||
return configs
|
||||
|
||||
|
||||
def prune_rocm_search_space(
|
||||
num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
|
||||
):
|
||||
N1, K1 = shard_intermediate_size, hidden_size
|
||||
N2, K2 = hidden_size, shard_intermediate_size // 2
|
||||
pruned_space_1 = prune_rocm_configs(
|
||||
num_tokens * topk, N1, K1, search_space, is_fp16
|
||||
)
|
||||
pruned_space_2 = prune_rocm_configs(
|
||||
num_tokens * topk, N2, K2, search_space, is_fp16
|
||||
)
|
||||
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
|
||||
return search_space
|
||||
|
||||
|
||||
# The following code is inspired by ROCm/Triton GEMM tuning script:
|
||||
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
|
||||
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
|
||||
pruned_configs = []
|
||||
elemBytes_a = 2 if is_fp16 else 1
|
||||
elemBytes_b = 2 if is_fp16 else 1
|
||||
|
||||
mfma = 16 if M < 32 or N < 32 else 32
|
||||
|
||||
# TODO (zhanglx): figure out the boundary between large and small gemms
|
||||
large_gemm = False
|
||||
if M >= 2048 and N >= 2048:
|
||||
large_gemm = True
|
||||
|
||||
for config in configs:
|
||||
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||
num_warps = config.get("num_warps")
|
||||
|
||||
# DCU currently does not support matrix_instr_nonkdim param
|
||||
# if is_fp16:
|
||||
# matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
|
||||
# if matrix_instr_nonkdim > mfma:
|
||||
# continue
|
||||
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
# some layouts could not work properly in case
|
||||
# number elements per thread is less 1
|
||||
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
SPLIT_K = config.get("SPLIT_K", 1)
|
||||
GROUP_M = config.get("GROUP_SIZE_M")
|
||||
|
||||
# DCU currently does not support matrix_instr_nonkdim param
|
||||
# if is_fp16:
|
||||
# if (
|
||||
# matrix_instr_nonkdim > BLOCK_SIZE_M
|
||||
# or matrix_instr_nonkdim > BLOCK_SIZE_N
|
||||
# ):
|
||||
# continue
|
||||
# if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
|
||||
# continue
|
||||
# if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
|
||||
# continue
|
||||
|
||||
# Skip BLOCK_SIZE that is too large compare to M/N
|
||||
# unless BLOCK_SIZE is already small enough
|
||||
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
|
||||
continue
|
||||
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
|
||||
continue
|
||||
# skip large split_k when not necessary
|
||||
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
||||
continue
|
||||
# skip split_k that leads to EVEN_K = false
|
||||
leap = SPLIT_K * BLOCK_SIZE_K
|
||||
modv = K % leap
|
||||
if modv != 0:
|
||||
continue
|
||||
# skip large GROUP_M
|
||||
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
||||
continue
|
||||
# out of shared memory resource
|
||||
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||
LDS = (
|
||||
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
||||
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
||||
)
|
||||
if LDS > 65536:
|
||||
continue
|
||||
# Skip small block sizes and num_warps for large gemm
|
||||
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
|
||||
if large_gemm:
|
||||
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
if BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
if num_warps < 4:
|
||||
continue
|
||||
|
||||
pruned_configs.append(config)
|
||||
|
||||
return pruned_configs
|
||||
|
||||
|
||||
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
|
||||
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
|
||||
|
||||
|
||||
def merge_unique_dicts(list1, list2):
|
||||
result = []
|
||||
combined_list = list1.copy()
|
||||
combined_list.extend(list2)
|
||||
for dictionary in combined_list:
|
||||
if dictionary not in result:
|
||||
result.append(dictionary)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
def __init__(self, seed: int, device_id: int) -> None:
|
||||
from vllm.platforms import current_platform
|
||||
import os
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# In ROCm environment with Ray, let Ray handle device assignment
|
||||
# Don't manually set default device as it may conflict with Ray's device mapping
|
||||
pass
|
||||
else:
|
||||
torch.set_default_device("cuda:"+ str(device_id))
|
||||
current_platform.seed_everything(seed)
|
||||
self.seed = seed
|
||||
# Store the logical device ID for Ray
|
||||
self.device_id = device_id
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_quant_shape: list[int] = None,
|
||||
use_deep_gemm: bool = False,
|
||||
nn_moe: Optional[bool] = False,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
# 局部导入 current_platform
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.seed_everything(self.seed)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, get_moe_configs, get_default_config
|
||||
)
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
op_config = get_moe_configs(
|
||||
num_experts, shard_intermediate_size // 2, dtype_str, use_nn_moe=nn_moe
|
||||
)
|
||||
if op_config is None:
|
||||
config = get_default_config(
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype_str,
|
||||
is_marlin=False,
|
||||
use_nn_moe=nn_moe
|
||||
)
|
||||
else:
|
||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
block_quant_shape=block_quant_shape,
|
||||
use_deep_gemm=use_deep_gemm,
|
||||
use_nn_moe=nn_moe
|
||||
)
|
||||
return config, kernel_time
|
||||
|
||||
def tune(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
search_space: list[dict[str, int]],
|
||||
block_quant_shape: list[int],
|
||||
use_deep_gemm: bool,
|
||||
nn_moe: Optional[bool] = False,
|
||||
) -> dict[str, int]:
|
||||
from vllm.platforms import current_platform
|
||||
import os
|
||||
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
if current_platform.is_rocm():
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = prune_rocm_search_space(
|
||||
num_tokens,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
search_space,
|
||||
is_fp16,
|
||||
topk,
|
||||
)
|
||||
|
||||
# In ROCm environments with Ray, device context is already handled by Ray
|
||||
# Using torch.cuda.device() may cause device ordinal conflicts
|
||||
need_device_guard = False
|
||||
if current_platform.is_rocm():
|
||||
# For ROCm with Ray, skip additional device context management
|
||||
need_device_guard = False
|
||||
else:
|
||||
# For other platforms, use device guard if needed
|
||||
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if visible_devices is not None and len(visible_devices.split(',')) > 1:
|
||||
need_device_guard = True
|
||||
|
||||
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=20,
|
||||
block_quant_shape=block_quant_shape,
|
||||
use_deep_gemm=use_deep_gemm,
|
||||
nn_moe=nn_moe)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
|
||||
assert best_config is not None
|
||||
return best_config
|
||||
|
||||
|
||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
|
||||
return {
|
||||
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
||||
"num_warps": config["num_warps"],
|
||||
"num_stages": config["num_stages"],
|
||||
**(
|
||||
{"num_ldmatrixes": config["num_ldmatrixes"]} if "num_ldmatrixes" in config else {}
|
||||
),
|
||||
**(
|
||||
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
|
||||
),
|
||||
**(
|
||||
{"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
|
||||
if "matrix_instr_nonkdim" in config
|
||||
else {}
|
||||
),
|
||||
**({"kpack": config["kpack"]} if "kpack" in config else {}),
|
||||
}
|
||||
|
||||
|
||||
def save_configs(
|
||||
configs: dict[int, BenchmarkConfig],
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_quant_shape: list[int],
|
||||
use_nn_moe: Optional[bool] = False,
|
||||
) -> None:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, get_config_file_name
|
||||
)
|
||||
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
filename = get_config_file_name(
|
||||
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape, use_nn_moe=use_nn_moe
|
||||
)
|
||||
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def get_weight_block_size_safety(config, default_value=None):
|
||||
quantization_config = getattr(config, "quantization_config", {})
|
||||
if isinstance(quantization_config, dict):
|
||||
return quantization_config.get("weight_block_size", default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
import os
|
||||
import logging
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
print(args)
|
||||
|
||||
tp_size = args.tp_size
|
||||
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
|
||||
if args.model_prefix:
|
||||
config = getattr(config, args.model_prefix)
|
||||
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM", "Glm4MoeForCausalLM"):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
elif config.architectures[0] in ("Step3VLForConditionalGeneration"):
|
||||
E = config.text_config.moe_num_experts
|
||||
topk = config.text_config.moe_top_k
|
||||
intermediate_size = config.text_config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
else:
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // tp_size
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
64,
|
||||
96,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = args.batch_size
|
||||
|
||||
use_deep_gemm = bool(args.use_deep_gemm)
|
||||
|
||||
if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
|
||||
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
|
||||
logger.warning(
|
||||
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
|
||||
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
|
||||
)
|
||||
val = os.environ["HIP_VISIBLE_DEVICES"]
|
||||
os.environ["ROCR_VISIBLE_DEVICES"] = val
|
||||
del os.environ["HIP_VISIBLE_DEVICES"]
|
||||
|
||||
ray.init(address=None, ignore_reinit_error=True, num_gpus=args.num_gpus)
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed, i) for i in range(num_gpus)]
|
||||
|
||||
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
|
||||
outputs = []
|
||||
worker_idx = 0
|
||||
for input_args in inputs:
|
||||
worker = workers[worker_idx]
|
||||
worker_method = getattr(worker, method)
|
||||
output = worker_method.remote(*input_args)
|
||||
outputs.append(output)
|
||||
worker_idx = (worker_idx + 1) % num_gpus
|
||||
return ray.get(outputs)
|
||||
|
||||
if args.tune:
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = get_configs_compute_bound(is_fp16, block_quant_shape, args.nn_moe)
|
||||
print(f"Start tuning over {len(search_space)} configurations...")
|
||||
|
||||
start = time.time()
|
||||
configs = _distribute(
|
||||
"tune",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
search_space,
|
||||
block_quant_shape,
|
||||
use_deep_gemm,
|
||||
args.nn_moe,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
best_configs = {
|
||||
M: sort_config(config) for M, config in zip(batch_sizes, configs)
|
||||
}
|
||||
save_configs(
|
||||
best_configs,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_quant_shape,
|
||||
use_nn_moe=args.nn_moe,
|
||||
)
|
||||
end = time.time()
|
||||
print(f"Tuning took {end - start:.2f} seconds")
|
||||
else:
|
||||
outputs = _distribute(
|
||||
"benchmark",
|
||||
[
|
||||
(
|
||||
batch_size,
|
||||
E,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_quant_shape,
|
||||
use_deep_gemm,
|
||||
args.nn_moe,
|
||||
)
|
||||
for batch_size in batch_sizes
|
||||
],
|
||||
)
|
||||
|
||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}, config: {config}")
|
||||
print(f"Kernel time: {kernel_time:.2f} us")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||
)
|
||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--nn-moe", action='store_true', default=False)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
parser.add_argument("--model-prefix", type=str, required=False)
|
||||
parser.add_argument("--num-gpus", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
Reference in New Issue
Block a user