[5/N] MoE Refactor: Update MoE parallelism arguments (#8658)
@@ -172,12 +172,11 @@ class ServerArgs:

     # Expert parallelism
     ep_size: int = 1
-    enable_ep_moe: bool = False
-    enable_deepep_moe: bool = False
+    moe_a2a_backend: Optional[Literal["deepep"]] = None
     enable_flashinfer_cutlass_moe: bool = False
     enable_flashinfer_trtllm_moe: bool = False
     enable_flashinfer_allreduce_fusion: bool = False
-    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
+    deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
     init_expert_location: str = "trivial"
@@ -272,7 +271,27 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3

+    # Deprecated arguments
+    enable_ep_moe: bool = False
+    enable_deepep_moe: bool = False
+
     def __post_init__(self):
+        # Check deprecated arguments
+        def print_deprecated_warning(message: str):
+            logger.warning(f"\033[33m{message}\033[0m")
+
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            print_deprecated_warning(
+                "NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
+            )
+        if self.enable_deepep_moe:
+            self.moe_a2a_backend = "deepep"
+            print_deprecated_warning(
+                "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
+            )
+
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
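
For reference, a minimal runnable sketch of the back-compat mapping this hunk introduces: old flags still take effect but emit a warning pointing at the new ones. `_Args` below is a hypothetical stand-in for illustration, not the full `ServerArgs`.

```python
# Minimal sketch of the deprecation mapping above; `_Args` is a
# hypothetical stand-in, not the full SGLang ServerArgs.
import logging
from dataclasses import dataclass
from typing import Literal, Optional

logger = logging.getLogger(__name__)


@dataclass
class _Args:
    tp_size: int = 8
    ep_size: int = 1
    moe_a2a_backend: Optional[Literal["deepep"]] = None
    # Deprecated flags, kept so old launch commands still parse.
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False

    def __post_init__(self):
        if self.enable_ep_moe:
            self.ep_size = self.tp_size
            logger.warning("--enable-ep-moe is deprecated; set --ep-size instead.")
        if self.enable_deepep_moe:
            self.moe_a2a_backend = "deepep"
            logger.warning("--enable-deepep-moe is deprecated; set --moe-a2a-backend deepep.")


args = _Args(enable_deepep_moe=True)
assert args.moe_a2a_backend == "deepep"  # old flag still takes effect
```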
@@ -455,14 +474,13 @@ class ServerArgs:
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
             os.environ["TRTLLM_ENABLE_PDL"] = "1"
-            if self.enable_ep_moe:
-                self.ep_size = self.tp_size
-                logger.warning(
-                    f"Flashinfer cutlass MoE and EP MoE are enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-                )
+            assert self.ep_size in [
+                1,
+                self.tp_size,
+            ], "The expert parallel size must be 1 or the same as the tensor parallel size"

         # DeepEP MoE
-        if self.enable_deepep_moe:
+        if self.moe_a2a_backend == "deepep":
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
@@ -486,7 +504,7 @@ class ServerArgs:
             )

         if self.enable_eplb:
-            assert self.enable_ep_moe or self.enable_deepep_moe
+            assert self.ep_size > 1 or self.moe_a2a_backend is not None

         if self.enable_expert_distribution_metrics and (
             self.expert_distribution_recorder_mode is None
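
The checks in these two hunks boil down to three invariants: `ep_size` must be 1 or equal to `tp_size`, DeepEP's `normal` mode cannot run under CUDA graphs, and EPLB now requires `ep_size > 1` or an A2A backend rather than the deprecated flags. A hedged standalone sketch (the helper name is hypothetical):

```python
# Hypothetical helper summarizing the validation above: ep_size must be
# 1 or tp_size; deepep's "normal" mode cannot use CUDA graphs; EPLB
# needs expert parallelism or an A2A backend.
from typing import Optional


def check_moe_args(ep_size: int, tp_size: int, moe_a2a_backend: Optional[str],
                   deepep_mode: str, enable_eplb: bool) -> bool:
    """Return True if the CUDA graph must be disabled."""
    assert ep_size in (1, tp_size), \
        "The expert parallel size must be 1 or the same as the tensor parallel size"
    if enable_eplb:
        assert ep_size > 1 or moe_a2a_backend is not None
    return moe_a2a_backend == "deepep" and deepep_mode == "normal"


assert check_moe_args(8, 8, "deepep", "normal", enable_eplb=True) is True
```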
@@ -1354,30 +1372,27 @@ class ServerArgs:
             help="The expert parallelism size.",
         )
         parser.add_argument(
-            "--enable-ep-moe",
-            action="store_true",
-            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+            "--moe-a2a-backend",
+            type=str,
+            choices=["deepep"],
+            default=ServerArgs.moe_a2a_backend,
+            help="Choose the backend for MoE A2A.",
         )
         parser.add_argument(
             "--enable-flashinfer-cutlass-moe",
             action="store_true",
-            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
+            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
         )
         parser.add_argument(
             "--enable-flashinfer-trtllm-moe",
             action="store_true",
-            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
+            help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
         )
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
             help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
         )
-        parser.add_argument(
-            "--enable-deepep-moe",
-            action="store_true",
-            help="Enabling DeepEP MoE implementation for EP MoE.",
-        )
         parser.add_argument(
             "--deepep-mode",
             type=str,
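
As a sanity check, the new flag can be exercised with a standalone parser that reproduces just the definitions above (defaults inlined, since `ServerArgs` is not imported here):

```python
# Standalone reproduction of the new argument; defaults inlined since
# ServerArgs isn't imported in this snippet.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--moe-a2a-backend", type=str, choices=["deepep"],
                    default=None, help="Choose the backend for MoE A2A.")
parser.add_argument("--ep-size", type=int, default=1,
                    help="The expert parallelism size.")

args = parser.parse_args(["--moe-a2a-backend", "deepep", "--ep-size", "8"])
assert args.moe_a2a_backend == "deepep" and args.ep_size == 8
```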
@@ -1839,6 +1854,18 @@ class ServerArgs:
             help="Disable mmap while loading weight using safetensors.",
         )

+        # Deprecated arguments
+        parser.add_argument(
+            "--enable-ep-moe",
+            action="store_true",
+            help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+        )
+        parser.add_argument(
+            "--enable-deepep-moe",
+            action="store_true",
+            help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
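
`from_cli_args` is where CLI-only spellings are reconciled with field names (`tensor_parallel_size` becomes `tp_size`). A hedged sketch of that pattern with a hypothetical dataclass:

```python
# Hypothetical sketch of the from_cli_args pattern: alias CLI-only
# names, then keep just the attributes that match dataclass fields.
import argparse
from dataclasses import dataclass, fields


@dataclass
class _Args:
    tp_size: int = 1
    ep_size: int = 1


def from_cli_args(ns: argparse.Namespace) -> _Args:
    ns.tp_size = ns.tensor_parallel_size  # CLI name differs from the field name
    names = {f.name for f in fields(_Args)}
    return _Args(**{k: v for k, v in vars(ns).items() if k in names})


ns = argparse.Namespace(tensor_parallel_size=8, ep_size=8)
assert from_cli_args(ns).tp_size == 8
```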