Fix num_qps_per_rank computation when providing custom DeepEP configuration (#6468)
This commit is contained in:
@@ -67,9 +67,9 @@ class DeepEPBuffer:
|
||||
if deepep_mode.enable_normal():
|
||||
hidden_bytes = hidden_size * param_bytes
|
||||
for config in (
|
||||
- _DeepEPConfig.get_instance().normal_dispatch_config
|
||||
+ DeepEPConfig.get_instance().normal_dispatch_config
|
||||
or Buffer.get_dispatch_config(group.size()),
|
||||
- _DeepEPConfig.get_instance().normal_combine_config
|
||||
+ DeepEPConfig.get_instance().normal_combine_config
|
||||
or Buffer.get_combine_config(group.size()),
|
||||
):
|
||||
num_nvl_bytes = max(
|
||||
@@ -97,7 +97,12 @@ class DeepEPBuffer:
|
||||
num_nvl_bytes,
|
||||
num_rdma_bytes,
|
||||
low_latency_mode=deepep_mode.enable_low_latency(),
|
||||
- num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)),
|
||||
+ num_qps_per_rank=(
|
||||
+ max(
|
||||
+ num_experts // group.size(),
|
||||
+ DeepEPConfig.get_instance().num_sms // 2,
|
||||
+ )
|
||||
+ ),
|
||||
)
|
||||
return cls._buffer
|
||||
|
||||
@@ -122,7 +127,7 @@ class DeepEPBuffer:
|
||||
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
||||
|
||||
|
||||
- class _DeepEPConfig:
|
||||
+ class DeepEPConfig:
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
@@ -131,16 +136,23 @@ class _DeepEPConfig:
|
||||
config_parsed = load_json_config(config_str)
|
||||
if torch.distributed.get_rank() == 0:
|
||||
logger.info(f"Use DeepEP Config: {config_parsed}")
|
||||
- self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"])
|
||||
- self.normal_combine_config = Config(**config_parsed["normal_combine"])
|
||||
+ config_dispatch = config_parsed["normal_dispatch"]
|
||||
+ config_combine = config_parsed["normal_combine"]
|
||||
|
||||
+ self.normal_dispatch_config = Config(**config_dispatch)
|
||||
+ self.normal_combine_config = Config(**config_combine)
|
||||
|
||||
+ assert config_dispatch["num_sms"] == config_combine["num_sms"]
|
||||
+ self.num_sms = config_dispatch["num_sms"]
|
||||
else:
|
||||
self.normal_dispatch_config = None
|
||||
self.normal_combine_config = None
|
||||
+ self.num_sms = Buffer.num_sms
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
- cls._instance = _DeepEPConfig()
|
||||
+ cls._instance = DeepEPConfig()
|
||||
return cls._instance
|
||||
|
||||
|
||||
@@ -326,7 +338,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
async_finish=self.async_finish,
|
||||
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
||||
- config=_DeepEPConfig.get_instance().normal_dispatch_config,
|
||||
+ config=DeepEPConfig.get_instance().normal_dispatch_config,
|
||||
)
|
||||
|
||||
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
||||
@@ -433,7 +445,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
||||
async_finish=self.async_finish,
|
||||
previous_event=previous_event,
|
||||
allocate_on_comm_stream=previous_event is not None,
|
||||
- config=_DeepEPConfig.get_instance().normal_combine_config,
|
||||
+ config=DeepEPConfig.get_instance().normal_combine_config,
|
||||
)
|
||||
return combined_x, event
|
||||
|
||||
|
||||
Reference in New Issue
Block a user