[Fix] incorrect assert in EPLB (#7575)

This commit is contained in:
Cheng Wan
2025-06-26 14:59:20 -07:00
committed by GitHub
parent bb9b608c86
commit 1b8cf77b01

View File

@@ -479,10 +479,6 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
topk_ids = topk_ids.flatten()
mask = topk_ids != -1
assert self._data[layer_idx, :].shape == topk_ids.shape, (
"Shape mismatch between data and topk_ids."
"Selecting expert is not supported for multiple token prediction at the moment."
)
self._data[layer_idx, :].scatter_add_(
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
)