Tuning script for the DeepSeek V3/R1 INT8 quantization (block-wise) feature (#3922)

Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
laixin
2025-02-27 18:59:46 +08:00
committed by GitHub
parent 3e02526b1f
commit b0df5d240b
16 changed files with 2129 additions and 28 deletions

View File

@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
_w8a8_block_fp8_matmul,
_w8a8_block_fp8_matmul_unrolledx4,
)
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
is_hip_ = is_hip()
@@ -42,7 +43,7 @@ DTYPE_MAP = {
}
def w8a8_block_fp8_matmul(
def w8a8_block_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -94,11 +95,15 @@ def w8a8_block_fp8_matmul(
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
else:
kernel = _w8a8_block_int8_matmul
kernel[grid](
A,
@@ -208,10 +213,10 @@ def get_weight_shapes(tp_size):
def benchmark_config(
A_fp8, B_fp8, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
):
def run():
w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, config, out_dtype)
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize()
# JIT compilation & warmup
@@ -234,20 +239,41 @@ def benchmark_config(
return avg
def tune(M, N, K, block_size, out_dtype, search_space):
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
)
if input_type == "fp8":
fp8_info = torch.finfo(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
)
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
)
else:
int8_info = torch.iinfo(torch.int8)
int8_max, int8_min = int8_info.max, int8_info.min
A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
)
A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
)
B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n
@@ -264,8 +290,8 @@ def tune(M, N, K, block_size, out_dtype, search_space):
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
A_fp8,
B_fp8,
A,
B,
As,
Bs,
block_size,
@@ -293,10 +319,11 @@ def save_configs(
block_k,
configs,
save_path,
input_type="fp8",
) -> None:
os.makedirs(save_path, exist_ok=True)
device_name = get_device_name().replace(" ", "_")
json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json"
config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...")
@@ -325,6 +352,7 @@ def tune_on_gpu(args_dict):
block_k = args.block_k
out_dtype = DTYPE_MAP[args.out_dtype]
save_path = args.save_path
input_type = args.input_type
search_space = get_configs_compute_bound()
search_space = [
@@ -337,11 +365,19 @@ def tune_on_gpu(args_dict):
N, K = shape[0], shape[1]
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
benchmark_results = [
tune(batch_size, N, K, [block_n, block_k], out_dtype, search_space)
tune(
batch_size,
N,
K,
[block_n, block_k],
out_dtype,
search_space,
input_type,
)
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
]
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
save_configs(N, K, block_n, block_k, best_configs, save_path)
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
end = time.time()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
@@ -418,6 +454,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument(
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
)
parser.add_argument(
"--out-dtype",
type=str,