Speed up expert location update (#6661)
This commit is contained in:
@@ -12,7 +12,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
import logging
|
import logging
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -348,14 +347,8 @@ def update_expert_weights_single_layer(
|
|||||||
return
|
return
|
||||||
|
|
||||||
reqs = torch.distributed.batch_isend_irecv(p2p_ops)
|
reqs = torch.distributed.batch_isend_irecv(p2p_ops)
|
||||||
try:
|
for req in reqs:
|
||||||
for req in reqs:
|
req.wait()
|
||||||
req.wait(timeout=timedelta(seconds=30))
|
|
||||||
except RuntimeError:
|
|
||||||
logger.error(
|
|
||||||
f"Context: {rank=} {old_physical_to_logical_map=} {new_physical_to_logical_map=} {num_local_physical_experts=} {num_gpu_per_node=}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _execute_buffer2weight_copies(buffer2weight_copy_infos):
|
def _execute_buffer2weight_copies(buffer2weight_copy_infos):
|
||||||
for (
|
for (
|
||||||
|
|||||||
Reference in New Issue
Block a user