[Fix] Add per_channel_quant parameter to MoE config functions (#11201)
This commit is contained in:
committed by
GitHub
parent
516738b096
commit
c7867b6702
@@ -16,14 +16,19 @@ _is_hip = is_hip()
|
|||||||
|
|
||||||
|
|
||||||
def get_config_file_name(
|
def get_config_file_name(
|
||||||
E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None
|
E: int,
|
||||||
|
N: int,
|
||||||
|
dtype: Optional[str],
|
||||||
|
block_shape: Optional[int] = None,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
device_name = get_device_name().replace(" ", "_")
|
device_name = get_device_name().replace(" ", "_")
|
||||||
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
dtype_selector = "" if not dtype else f",dtype={dtype}"
|
||||||
block_shape_selector = (
|
block_shape_selector = (
|
||||||
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
|
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
|
||||||
)
|
)
|
||||||
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
|
per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else ""
|
||||||
|
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json"
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
@@ -33,6 +38,7 @@ def get_moe_configs(
|
|||||||
dtype: Optional[str],
|
dtype: Optional[str],
|
||||||
block_n: Optional[int] = 0,
|
block_n: Optional[int] = 0,
|
||||||
block_k: Optional[int] = 0,
|
block_k: Optional[int] = 0,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
) -> Optional[Dict[int, Any]]:
|
) -> Optional[Dict[int, Any]]:
|
||||||
"""
|
"""
|
||||||
Return optimized configurations for the fused MoE kernel.
|
Return optimized configurations for the fused MoE kernel.
|
||||||
@@ -47,7 +53,9 @@ def get_moe_configs(
|
|||||||
|
|
||||||
# First look up if an optimized configuration is available in the configs
|
# First look up if an optimized configuration is available in the configs
|
||||||
# directory
|
# directory
|
||||||
json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k])
|
json_file_name = get_config_file_name(
|
||||||
|
E, N, dtype, [block_n, block_k], per_channel_quant
|
||||||
|
)
|
||||||
|
|
||||||
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
|
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
|
||||||
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
|
# so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
|
||||||
|
|||||||
Reference in New Issue
Block a user