add dsv3 mi300 triton config for block scale (#3146)
This commit is contained in:
@@ -18,6 +18,9 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||
get_default_config,
|
||||
get_moe_configs,
|
||||
)
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
_is_hip_ = is_hip()
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
@@ -102,8 +105,8 @@ def benchmark_config(
|
||||
(num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32
|
||||
)
|
||||
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
|
||||
|
||||
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
@@ -165,17 +168,15 @@ def benchmark_config(
|
||||
return avg
|
||||
|
||||
|
||||
def get_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
def get_rocm_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
configs: List[BenchmarkConfig] = []
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128, 256]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
waves_per_eu_range = 0
|
||||
for num_stages in [2]:
|
||||
for block_m in [32, 64, 128, 256]:
|
||||
for block_k in [32, 64, 128, 256]:
|
||||
for block_n in [16, 32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
for group_size in [1, 4, 8, 16, 32]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
@@ -184,11 +185,39 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
"waves_per_eu": waves_per_eu_range,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def get_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
configs: List[BenchmarkConfig] = []
|
||||
if _is_hip_:
|
||||
configs = get_rocm_configs_compute_bound()
|
||||
else:
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128, 256]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
@@ -297,6 +326,9 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
||||
"num_warps": config["num_warps"],
|
||||
"num_stages": config["num_stages"],
|
||||
**(
|
||||
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user