[Fix] incorrect assert in EPLB (#7575)
This commit is contained in:
@@ -479,10 +479,6 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
|
||||
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
||||
topk_ids = topk_ids.flatten()
|
||||
mask = topk_ids != -1
|
||||
assert self._data[layer_idx, :].shape == topk_ids.shape, (
|
||||
"Shape mismatch between data and topk_ids."
|
||||
"Selecting expert is not supported for multiple token prediction at the moment."
|
||||
)
|
||||
self._data[layer_idx, :].scatter_add_(
|
||||
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user