Fix num_qps_per_rank computation when providing custom DeepEP configuration (#6468)
This commit is contained in:
@@ -67,9 +67,9 @@ class DeepEPBuffer:
|
|||||||
if deepep_mode.enable_normal():
|
if deepep_mode.enable_normal():
|
||||||
hidden_bytes = hidden_size * param_bytes
|
hidden_bytes = hidden_size * param_bytes
|
||||||
for config in (
|
for config in (
|
||||||
_DeepEPConfig.get_instance().normal_dispatch_config
|
DeepEPConfig.get_instance().normal_dispatch_config
|
||||||
or Buffer.get_dispatch_config(group.size()),
|
or Buffer.get_dispatch_config(group.size()),
|
||||||
_DeepEPConfig.get_instance().normal_combine_config
|
DeepEPConfig.get_instance().normal_combine_config
|
||||||
or Buffer.get_combine_config(group.size()),
|
or Buffer.get_combine_config(group.size()),
|
||||||
):
|
):
|
||||||
num_nvl_bytes = max(
|
num_nvl_bytes = max(
|
||||||
@@ -97,7 +97,12 @@ class DeepEPBuffer:
|
|||||||
num_nvl_bytes,
|
num_nvl_bytes,
|
||||||
num_rdma_bytes,
|
num_rdma_bytes,
|
||||||
low_latency_mode=deepep_mode.enable_low_latency(),
|
low_latency_mode=deepep_mode.enable_low_latency(),
|
||||||
num_qps_per_rank=(max(num_experts // group.size(), Buffer.num_sms // 2)),
|
num_qps_per_rank=(
|
||||||
|
max(
|
||||||
|
num_experts // group.size(),
|
||||||
|
DeepEPConfig.get_instance().num_sms // 2,
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return cls._buffer
|
return cls._buffer
|
||||||
|
|
||||||
@@ -122,7 +127,7 @@ class DeepEPBuffer:
|
|||||||
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
|
||||||
|
|
||||||
|
|
||||||
class _DeepEPConfig:
|
class DeepEPConfig:
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -131,16 +136,23 @@ class _DeepEPConfig:
|
|||||||
config_parsed = load_json_config(config_str)
|
config_parsed = load_json_config(config_str)
|
||||||
if torch.distributed.get_rank() == 0:
|
if torch.distributed.get_rank() == 0:
|
||||||
logger.info(f"Use DeepEP Config: {config_parsed}")
|
logger.info(f"Use DeepEP Config: {config_parsed}")
|
||||||
self.normal_dispatch_config = Config(**config_parsed["normal_dispatch"])
|
config_dispatch = config_parsed["normal_dispatch"]
|
||||||
self.normal_combine_config = Config(**config_parsed["normal_combine"])
|
config_combine = config_parsed["normal_combine"]
|
||||||
|
|
||||||
|
self.normal_dispatch_config = Config(**config_dispatch)
|
||||||
|
self.normal_combine_config = Config(**config_combine)
|
||||||
|
|
||||||
|
assert config_dispatch["num_sms"] == config_combine["num_sms"]
|
||||||
|
self.num_sms = config_dispatch["num_sms"]
|
||||||
else:
|
else:
|
||||||
self.normal_dispatch_config = None
|
self.normal_dispatch_config = None
|
||||||
self.normal_combine_config = None
|
self.normal_combine_config = None
|
||||||
|
self.num_sms = Buffer.num_sms
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls):
|
def get_instance(cls):
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = _DeepEPConfig()
|
cls._instance = DeepEPConfig()
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
|
|
||||||
@@ -326,7 +338,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
async_finish=self.async_finish,
|
async_finish=self.async_finish,
|
||||||
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
|
||||||
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
|
||||||
config=_DeepEPConfig.get_instance().normal_dispatch_config,
|
config=DeepEPConfig.get_instance().normal_dispatch_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
|
||||||
@@ -433,7 +445,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
async_finish=self.async_finish,
|
async_finish=self.async_finish,
|
||||||
previous_event=previous_event,
|
previous_event=previous_event,
|
||||||
allocate_on_comm_stream=previous_event is not None,
|
allocate_on_comm_stream=previous_event is not None,
|
||||||
config=_DeepEPConfig.get_instance().normal_combine_config,
|
config=DeepEPConfig.get_instance().normal_combine_config,
|
||||||
)
|
)
|
||||||
return combined_x, event
|
return combined_x, event
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user