Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -19,6 +19,8 @@ from torch.distributed import (
     get_global_rank,
 )
 
+from vllm.distributed.parallel_state import get_ep_group
+from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
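
Note on the new imports: the hunk above pulls in get_ep_group and the StatelessGroupCoordinator type that the isinstance check below relies on. As a hedged sketch (not part of this commit, helper name illustrative), the import could be guarded so the module still loads on vLLM builds that do not ship vllm.distributed.stateless_coordinator:

    try:
        from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
    except ImportError:
        # Older vLLM build without the stateless coordinator module.
        StatelessGroupCoordinator = None

    def _is_stateless_group(group) -> bool:
        # Treat the group as stateful whenever the coordinator type is unavailable.
        return StatelessGroupCoordinator is not None and isinstance(
            group, StatelessGroupCoordinator
        )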

@@ -249,10 +251,18 @@ def move_to_buffer(
         b[dst].copy_(w[src_local], non_blocking=True)
 
     p2p_ops: list[P2POp] = []
+    if isinstance(get_ep_group(), StatelessGroupCoordinator):
+        ep_group = get_ep_group()
+        is_stateless = True
+    else:
+        is_stateless = False
 
-    # Pre-compute global ranks mapping
+    # Pre-compute global ranks mapping (only needed for non-stateless groups)
     ep_size = ep_group.size()
-    rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
+    if not is_stateless:
+        rank_to_global = {
+            rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
+        }
 
     # 2. Post sends
     if send_count > 0:
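
Note on this hunk: it detects whether the expert-parallel group returned by get_ep_group() is a StatelessGroupCoordinator and records that in is_stateless. The rank_to_global map is now built only for regular process groups, apparently because get_global_rank() translates group-local ranks through torch.distributed's registered groups, which the stateless coordinator path avoids. A hedged sketch of the same decision folded into one helper (illustrative names, not part of the commit):

    def _resolve_peer_mapping(ep_group):
        """Return (is_stateless, rank_to_global).

        A stateless coordinator addresses peers by group-local rank, so it
        needs no global-rank translation.
        """
        if isinstance(ep_group, StatelessGroupCoordinator):
            return True, None
        return False, {
            rank: get_global_rank(ep_group, rank) for rank in range(ep_group.size())
        }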

@@ -284,15 +294,23 @@ def move_to_buffer(
             if recver_pos < len(ranks_to_recv):
                 recv_ranks.append(ranks_to_recv[recver_pos])
             for dst in recv_ranks:
-                dst_global = rank_to_global[dst]
-                p2p_ops += [
-                    P2POp(
-                        torch.distributed.isend,
-                        w[src],
-                        dst_global,
-                    )
-                    for w in expert_weights
-                ]
+                if is_stateless:
+                    for w in expert_weights:
+                        op = object.__new__(P2POp)
+                        op.op = torch.distributed.isend
+                        op.tensor = w[src]
+                        op.group_peer = dst
+                        p2p_ops.append(op)
+                else:
+                    dst_global = rank_to_global[dst]
+                    p2p_ops += [
+                        P2POp(
+                            torch.distributed.isend,
+                            w[src],
+                            dst_global,
+                        )
+                        for w in expert_weights
+                    ]
 
     # 3. Post recvs
     if recv_count > 0:
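
Note on the stateless send path: the commit builds each P2POp with object.__new__, presumably so that P2POp.__init__ never tries to resolve the peer through a registered torch.distributed process group; it then fills in op, tensor, and group_peer by hand, keeping the peer as a group-local rank. A hedged sketch of that construction as a standalone helper (illustrative only, not part of the commit):

    import torch.distributed
    from torch.distributed import P2POp

    def _make_stateless_p2p_op(op_fn, tensor, group_peer):
        # Bypass P2POp.__init__ so no registered process group is consulted.
        op = object.__new__(P2POp)
        op.op = op_fn                # torch.distributed.isend or irecv
        op.tensor = tensor
        op.group_peer = group_peer   # peer expressed as a group-local rank
        return op

With such a helper, the stateless send branch could reduce to:

    p2p_ops.extend(
        _make_stateless_p2p_op(torch.distributed.isend, w[src], dst)
        for w in expert_weights
    )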

@@ -321,26 +339,40 @@ def move_to_buffer(
                 src = ranks_to_send[recver_pos // num_dst_per_sender]
             else:
                 src = ranks_to_send[recver_pos - remainder_start]
-            src_global = rank_to_global[src]
-            p2p_ops += [
-                P2POp(
-                    torch.distributed.irecv,
-                    b[dst],
-                    src_global,
-                )
-                for b in expert_weights_buffers
-            ]
+            if is_stateless:
+                for b in expert_weights_buffers:
+                    op = object.__new__(P2POp)
+                    op.op = torch.distributed.irecv
+                    op.tensor = b[dst]
+                    op.group_peer = src
+                    p2p_ops.append(op)
+            else:
+                src_global = rank_to_global[src]
+                p2p_ops += [
+                    P2POp(
+                        torch.distributed.irecv,
+                        b[dst],
+                        src_global,
+                    )
+                    for b in expert_weights_buffers
+                ]
 
     # 4. Execute the P2P operations. The real communication happens here.
     if p2p_ops and cuda_stream is not None:
         with torch.cuda.stream(cuda_stream):
+            if is_stateless:
+                ep_group.device_communicator.batch_isend_irecv(p2p_ops)
+            else:
+                reqs = batch_isend_irecv(p2p_ops)
+                for req in reqs:
+                    req.wait()
+    elif p2p_ops:
+        if is_stateless:
+            ep_group.device_communicator.batch_isend_irecv(p2p_ops)
+        else:
             reqs = batch_isend_irecv(p2p_ops)
             for req in reqs:
                 req.wait()
-    elif p2p_ops:
-        reqs = batch_isend_irecv(p2p_ops)
-        for req in reqs:
-            req.wait()
     # wait for the communication to finish
     return (
         is_unchanged,
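
Note on the execute step: for a stateless group the batch is handed to ep_group.device_communicator.batch_isend_irecv(p2p_ops) and the code does not wait on a return value, while the regular path keeps torch.distributed's batch_isend_irecv followed by req.wait() on every request. The same dispatch now appears twice, once under the CUDA stream and once without; a hedged sketch of pulling it into a helper so the two branches cannot drift apart (illustrative, not part of the commit; assumes the names from the hunk above):

    def _execute_p2p(ep_group, p2p_ops, is_stateless):
        if is_stateless:
            # The commit issues the whole batch through the stateless communicator
            # and does not wait on a return value here.
            ep_group.device_communicator.batch_isend_irecv(p2p_ops)
        else:
            for req in batch_isend_irecv(p2p_ops):
                req.wait()

    if p2p_ops and cuda_stream is not None:
        with torch.cuda.stream(cuda_stream):
            _execute_p2p(ep_group, p2p_ops, is_stateless)
    elif p2p_ops:
        _execute_p2p(ep_group, p2p_ops, is_stateless)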