[Fix] incorrect assert in EPLB (#7575)
This commit is contained in:
@@ -479,10 +479,6 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
|
|||||||
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
|
||||||
topk_ids = topk_ids.flatten()
|
topk_ids = topk_ids.flatten()
|
||||||
mask = topk_ids != -1
|
mask = topk_ids != -1
|
||||||
assert self._data[layer_idx, :].shape == topk_ids.shape, (
|
|
||||||
"Shape mismatch between data and topk_ids."
|
|
||||||
"Selecting expert is not supported for multiple token prediction at the moment."
|
|
||||||
)
|
|
||||||
self._data[layer_idx, :].scatter_add_(
|
self._data[layer_idx, :].scatter_add_(
|
||||||
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
|
dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user