[Fix] Add per_channel_quant parameter to MoE config functions (#11201)
This commit is contained in:
committed by
GitHub
parent
516738b096
commit
c7867b6702
@@ -16,14 +16,19 @@ _is_hip = is_hip()
|
||||
|
||||
|
||||
def get_config_file_name(
|
||||
E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None
|
||||
E: int,
|
||||
N: int,
|
||||
dtype: Optional[str],
|
||||
block_shape: Optional[int] = None,
|
||||
per_channel_quant: bool = False,
|
||||
) -> str:
|
||||
device_name = get_device_name().replace(" ", "_")
|
||||
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
||||
block_shape_selector = (
|
||||
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
|
||||
)
|
||||
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
|
||||
per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else ""
|
||||
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json"
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
@@ -33,6 +38,7 @@ def get_moe_configs(
|
||||
dtype: Optional[str],
|
||||
block_n: Optional[int] = 0,
|
||||
block_k: Optional[int] = 0,
|
||||
per_channel_quant: bool = False,
|
||||
) -> Optional[Dict[int, Any]]:
|
||||
"""
|
||||
Return optimized configurations for the fused MoE kernel.
|
||||
@@ -47,7 +53,9 @@ def get_moe_configs(
|
||||
|
||||
# First look up if an optimized configuration is available in the configs
|
||||
# directory
|
||||
json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k])
|
||||
json_file_name = get_config_file_name(
|
||||
E, N, dtype, [block_n, block_k], per_channel_quant
|
||||
)
|
||||
|
||||
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
|
||||
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
|
||||
|
||||
Reference in New Issue
Block a user