Feature: Tuning Script for DeepSeek V3/R1 INT8 Quantization (block-wise) (#3922)

Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
laixin
2025-02-27 18:59:46 +08:00
committed by GitHub
parent 3e02526b1f
commit b0df5d240b
16 changed files with 2129 additions and 28 deletions

View File

@@ -41,13 +41,14 @@ def benchmark_config(
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int] = None,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16:
if use_int8_w8a16 or use_int8_w8a8:
w1 = torch.randint(
-127,
127,
@@ -86,7 +87,7 @@ def benchmark_config(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
if use_fp8_w8a8 or use_int8_w8a8:
if block_shape is None:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
@@ -105,6 +106,7 @@ def benchmark_config(
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
@@ -126,6 +128,7 @@ def benchmark_config(
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
@@ -235,6 +238,7 @@ class BenchmarkWorker:
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
) -> Tuple[Dict[str, int], float]:
@@ -270,6 +274,7 @@ class BenchmarkWorker:
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
@@ -284,6 +289,7 @@ class BenchmarkWorker:
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
search_space: List[Dict[str, int]],
@@ -301,6 +307,7 @@ class BenchmarkWorker:
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
num_iters=10,
@@ -340,11 +347,15 @@ def save_configs(
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: List[int],
) -> None:
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
@@ -396,6 +407,7 @@ def main(args: argparse.Namespace):
hidden_size = config.hidden_size
dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a8 = args.dtype == "int8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_shape = None
if (
@@ -467,6 +479,7 @@ def main(args: argparse.Namespace):
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
search_space,
@@ -485,6 +498,7 @@ def main(args: argparse.Namespace):
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
@@ -502,6 +516,7 @@ def main(args: argparse.Namespace):
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
block_shape,
)
@@ -521,7 +536,10 @@ if __name__ == "__main__":
)
parser.add_argument("--tp-size", "-tp", type=int, default=2)
parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
"--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
default="auto",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)