Tuning Script for Feature DeepSeek V3/R1 INT8 Quantization (block-wise) (#3922)
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
This commit is contained in:
@@ -41,13 +41,14 @@ def benchmark_config(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: List[int] = None,
|
||||
num_iters: int = 100,
|
||||
) -> float:
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
- if use_int8_w8a16:
|
||||
+ if use_int8_w8a16 or use_int8_w8a8:
|
||||
w1 = torch.randint(
|
||||
-127,
|
||||
127,
|
||||
@@ -86,7 +87,7 @@ def benchmark_config(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
- if use_fp8_w8a8:
|
||||
+ if use_fp8_w8a8 or use_int8_w8a8:
|
||||
if block_shape is None:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
@@ -105,6 +106,7 @@ def benchmark_config(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
|
||||
)
|
||||
|
||||
if use_fp8_w8a8:
|
||||
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
|
||||
|
||||
@@ -126,6 +128,7 @@ def benchmark_config(
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
@@ -235,6 +238,7 @@ class BenchmarkWorker:
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: List[int],
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
@@ -270,6 +274,7 @@ class BenchmarkWorker:
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
)
|
||||
@@ -284,6 +289,7 @@ class BenchmarkWorker:
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: List[int],
|
||||
search_space: List[Dict[str, int]],
|
||||
@@ -301,6 +307,7 @@ class BenchmarkWorker:
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
num_iters=10,
|
||||
@@ -340,11 +347,15 @@ def save_configs(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: List[int],
|
||||
) -> None:
|
||||
dtype_str = get_config_dtype_str(
|
||||
- dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
+ dtype,
|
||||
+ use_int8_w8a16=use_int8_w8a16,
|
||||
+ use_fp8_w8a8=use_fp8_w8a8,
|
||||
+ use_int8_w8a8=use_int8_w8a8,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
@@ -396,6 +407,7 @@ def main(args: argparse.Namespace):
|
||||
hidden_size = config.hidden_size
|
||||
dtype = config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a8 = args.dtype == "int8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
block_shape = None
|
||||
if (
|
||||
@@ -467,6 +479,7 @@ def main(args: argparse.Namespace):
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
search_space,
|
||||
@@ -485,6 +498,7 @@ def main(args: argparse.Namespace):
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
)
|
||||
@@ -502,6 +516,7 @@ def main(args: argparse.Namespace):
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_shape,
|
||||
)
|
||||
@@ -521,7 +536,10 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("--tp-size", "-tp", type=int, default=2)
|
||||
parser.add_argument(
|
||||
- "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||
+ "--dtype",
|
||||
+ type=str,
|
||||
+ choices=["auto", "fp8_w8a8", "int8_w8a16", "int8_w8a8"],
|
||||
+ default="auto",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
|
||||
Reference in New Issue
Block a user