EPLB support for MTP (#7510)

This commit is contained in:
yilian49
2025-06-25 01:16:56 -07:00
committed by GitHub
parent 7b9a174a7a
commit 587b4c6e92
2 changed files with 28 additions and 4 deletions

View File

@@ -61,6 +61,10 @@ class ExpertDistributionRecorder(ABC):
def with_debug_name(self, debug_name):
yield
@contextmanager
def disable_this_region(self):
yield
@contextmanager
def with_forward_pass(self, forward_pass_id: int, forward_batch: ForwardBatch):
yield
@@ -116,6 +120,7 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
self._expert_location_metadata = expert_location_metadata
self._recording = False
self._disable_all = False
self._current_forward_pass_id = Withable()
self._current_layer_idx = Withable()
self._current_debug_name = Withable()
@@ -148,6 +153,16 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
finally:
self._on_forward_pass_end(forward_pass_id)
@contextmanager
def disable_this_region(self):
"""Context manager to temporarily disable recording."""
previous_disable_all = self._disable_all
self._disable_all = True
try:
yield
finally:
self._disable_all = previous_disable_all
def _on_forward_pass_start(self, forward_batch: ForwardBatch):
if not self._recording:
return
@@ -189,6 +204,8 @@ class _ExpertDistributionRecorderReal(ExpertDistributionRecorder):
)
def _on_hook(self, hook_name: str, **kwargs):
if self._disable_all:
return
if not (self._recording or torch.cuda.is_current_stream_capturing()):
return
gatherer = self._single_pass_gatherers[
@@ -462,6 +479,10 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
topk_ids = topk_ids.flatten()
mask = topk_ids != -1
assert self._data[layer_idx, :].shape == topk_ids.shape, (
"Shape mismatch between data and topk_ids."
"Selecting expert is not supported for multiple token prediction at the moment."
)
self._data[layer_idx, :].scatter_add_(
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
)