Speed up expert location update (#6661)
This commit is contained in:
@@ -12,7 +12,6 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
@@ -348,14 +347,8 @@ def update_expert_weights_single_layer(
|
||||
return
|
||||
|
||||
reqs = torch.distributed.batch_isend_irecv(p2p_ops)
|
||||
try:
|
||||
for req in reqs:
|
||||
req.wait(timeout=timedelta(seconds=30))
|
||||
except RuntimeError:
|
||||
logger.error(
|
||||
f"Context: {rank=} {old_physical_to_logical_map=} {new_physical_to_logical_map=} {num_local_physical_experts=} {num_gpu_per_node=}"
|
||||
)
|
||||
raise
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
def _execute_buffer2weight_copies(buffer2weight_copy_infos):
|
||||
for (
|
||||
|
||||
Reference in New Issue
Block a user