Tuning Script for Feature DeepSeek V3/R1 INT8 Quantization (block-wise) (#3922)
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
_w8a8_block_fp8_matmul,
|
||||
_w8a8_block_fp8_matmul_unrolledx4,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
|
||||
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
@@ -42,7 +43,7 @@ DTYPE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def w8a8_block_fp8_matmul(
|
||||
def w8a8_block_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
@@ -94,11 +95,15 @@ def w8a8_block_fp8_matmul(
|
||||
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
|
||||
N, config["BLOCK_SIZE_N"]
|
||||
)
|
||||
kernel = (
|
||||
_w8a8_block_fp8_matmul_unrolledx4
|
||||
if (is_hip_ == True and num_workgroups <= get_device_core_count())
|
||||
else _w8a8_block_fp8_matmul
|
||||
)
|
||||
|
||||
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
|
||||
kernel = (
|
||||
_w8a8_block_fp8_matmul_unrolledx4
|
||||
if (is_hip_ == True and num_workgroups <= get_device_core_count())
|
||||
else _w8a8_block_fp8_matmul
|
||||
)
|
||||
else:
|
||||
kernel = _w8a8_block_int8_matmul
|
||||
|
||||
kernel[grid](
|
||||
A,
|
||||
@@ -208,10 +213,10 @@ def get_weight_shapes(tp_size):
|
||||
|
||||
|
||||
def benchmark_config(
|
||||
A_fp8, B_fp8, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
|
||||
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
|
||||
):
|
||||
def run():
|
||||
w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, config, out_dtype)
|
||||
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# JIT compilation & warmup
|
||||
@@ -234,20 +239,41 @@ def benchmark_config(
|
||||
return avg
|
||||
|
||||
|
||||
def tune(M, N, K, block_size, out_dtype, search_space):
|
||||
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
)
|
||||
if input_type == "fp8":
|
||||
fp8_info = torch.finfo(
|
||||
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
)
|
||||
A_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
)
|
||||
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
B_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||
)
|
||||
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
|
||||
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
|
||||
)
|
||||
else:
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
int8_max, int8_min = int8_info.max, int8_info.min
|
||||
|
||||
A_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
|
||||
)
|
||||
A = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
B_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * int8_max
|
||||
)
|
||||
B = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
@@ -264,8 +290,8 @@ def tune(M, N, K, block_size, out_dtype, search_space):
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
A_fp8,
|
||||
B_fp8,
|
||||
A,
|
||||
B,
|
||||
As,
|
||||
Bs,
|
||||
block_size,
|
||||
@@ -293,10 +319,11 @@ def save_configs(
|
||||
block_k,
|
||||
configs,
|
||||
save_path,
|
||||
input_type="fp8",
|
||||
) -> None:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
device_name = get_device_name().replace(" ", "_")
|
||||
json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
|
||||
json_file_name = f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,block_shape=[{block_n}, {block_k}].json"
|
||||
|
||||
config_file_path = os.path.join(save_path, json_file_name)
|
||||
print(f"Writing best config to {config_file_path}...")
|
||||
@@ -325,6 +352,7 @@ def tune_on_gpu(args_dict):
|
||||
block_k = args.block_k
|
||||
out_dtype = DTYPE_MAP[args.out_dtype]
|
||||
save_path = args.save_path
|
||||
input_type = args.input_type
|
||||
|
||||
search_space = get_configs_compute_bound()
|
||||
search_space = [
|
||||
@@ -337,11 +365,19 @@ def tune_on_gpu(args_dict):
|
||||
N, K = shape[0], shape[1]
|
||||
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
|
||||
benchmark_results = [
|
||||
tune(batch_size, N, K, [block_n, block_k], out_dtype, search_space)
|
||||
tune(
|
||||
batch_size,
|
||||
N,
|
||||
K,
|
||||
[block_n, block_k],
|
||||
out_dtype,
|
||||
search_space,
|
||||
input_type,
|
||||
)
|
||||
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
|
||||
]
|
||||
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
|
||||
save_configs(N, K, block_n, block_k, best_configs, save_path)
|
||||
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
|
||||
|
||||
end = time.time()
|
||||
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||
@@ -418,6 +454,9 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-dtype",
|
||||
type=str,
|
||||
Reference in New Issue
Block a user