Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -28,11 +28,12 @@ from vllm.tracing import instrument
 from vllm.utils.async_utils import in_loop
 from vllm.utils.network_utils import (
     close_sockets,
-    get_open_port,
     get_open_zmq_inproc_path,
     make_zmq_socket,
 )
 from vllm.v1.engine import (
+    EEP_NOTIFICATION_CALL_ID,
+    EEPNotificationType,
     EngineCoreOutputs,
     EngineCoreRequest,
     EngineCoreRequestType,
@@ -47,6 +48,7 @@ from vllm.v1.engine.exceptions import EngineDeadError
 from vllm.v1.engine.utils import (
     CoreEngineActorManager,
     CoreEngineProcManager,
+    get_engine_zmq_addresses,
     launch_core_engines,
 )
 from vllm.v1.executor import Executor
@@ -445,6 +447,63 @@ class BackgroundResources:
             raise EngineDeadError()


+@dataclass
+class ElasticScalingCache:
+    existing_core_engines: list[EngineIdentity]
+    num_new_core_engines: int
+    pending_notifications: dict[EEPNotificationType, set[int]]
+
+
+def allocate_stateless_group_ports(parallel_config, new_data_parallel_size: int):
+    """
+    Allocate stateless group ports for elastic EP.
+    """
+    from vllm.utils.network_utils import get_open_ports_list
+
+    assert parallel_config.enable_elastic_ep, "Elastic EP must be enabled"
+    world_size = parallel_config.world_size
+    new_world_size_across_dp = world_size * new_data_parallel_size
+    num_world_groups = 1
+    num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)
+    num_ep_groups = max(
+        1,
+        new_world_size_across_dp
+        // (new_data_parallel_size * parallel_config.tensor_parallel_size),
+    )
+    num_eplb_groups = num_ep_groups
+    total_ports_needed = (
+        num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
+    ) * 3 + 5
+    all_ports = get_open_ports_list(total_ports_needed)
+    new_data_parallel_master_port_list = all_ports[-5:]
+    all_ports = all_ports[:-5]
+    new_stateless_world_group_port_list = [
+        all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
+    ]
+    start_idx = num_world_groups * 3
+    new_stateless_dp_group_port_list = [
+        all_ports[i : i + 3] for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
+    ]
+    start_idx += num_dp_groups * 3
+    new_stateless_ep_group_port_list = [
+        all_ports[i : i + 3] for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
+    ]
+    start_idx += num_ep_groups * 3
+    new_stateless_eplb_group_port_list = [
+        all_ports[i : i + 3]
+        for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
+    ]
+
+    parallel_config._stateless_world_group_port_list = (
+        new_stateless_world_group_port_list
+    )
+    parallel_config._stateless_dp_group_port_list = new_stateless_dp_group_port_list
+    parallel_config._stateless_ep_group_port_list = new_stateless_ep_group_port_list
+    parallel_config._stateless_eplb_group_port_list = new_stateless_eplb_group_port_list
+    parallel_config.data_parallel_master_port = new_data_parallel_master_port_list.pop()
+    parallel_config._data_parallel_master_port_list = new_data_parallel_master_port_list
+
+
 class MPClient(EngineCoreClient):
     """
     MPClient: base client for multi-proc EngineCore.
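
As a sanity check on the arithmetic above, consider a hypothetical configuration with tensor_parallel_size=2 and a per-replica world_size of 2 (TP=2, PP=1), scaled to 4 DP replicas: the function budgets 3 ports per stateless group plus 5 spare master ports, one of which is popped off as data_parallel_master_port. A standalone sketch:

# Worked example of the port budget computed by
# allocate_stateless_group_ports(). All values are hypothetical.
world_size = 2                   # TP * PP per engine replica
new_data_parallel_size = 4
tensor_parallel_size = 2

new_world_size_across_dp = world_size * new_data_parallel_size  # 8
num_world_groups = 1
num_dp_groups = max(1, new_world_size_across_dp // new_data_parallel_size)  # 2
num_ep_groups = max(
    1, new_world_size_across_dp // (new_data_parallel_size * tensor_parallel_size)
)  # 1
num_eplb_groups = num_ep_groups  # 1

# 3 ports per stateless group, plus 5 DP master ports; one of the 5
# becomes data_parallel_master_port, the rest are retained in
# _data_parallel_master_port_list for later reconfigurations.
total_ports_needed = (
    num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3 + 5
assert total_ports_needed == 20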
@@ -491,32 +550,37 @@ class MPClient(EngineCoreClient):
             input_address = client_addresses["input_address"]
             output_address = client_addresses["output_address"]
             self.stats_update_address = client_addresses.get("stats_update_address")
+            self.input_socket = self.resources.input_socket = make_zmq_socket(
+                self.ctx, input_address, zmq.ROUTER, bind=True
+            )
+            self.resources.output_socket = make_zmq_socket(
+                self.ctx, output_address, zmq.PULL
+            )
         else:
             # Engines are managed by this client.
-            with launch_core_engines(vllm_config, executor_class, log_stats) as (
-                engine_manager,
-                coordinator,
+            addresses = get_engine_zmq_addresses(vllm_config)
+            self.input_socket = self.resources.input_socket = make_zmq_socket(
+                self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
+            )
+            self.resources.output_socket = make_zmq_socket(
+                self.ctx, addresses.outputs[0], zmq.PULL
+            )
+
+            with launch_core_engines(
+                vllm_config,
+                executor_class,
+                log_stats,
                 addresses,
-            ):
+            ) as (engine_manager, coordinator, addresses):
                 self.resources.coordinator = coordinator
                 self.resources.engine_manager = engine_manager

-            (input_address,) = addresses.inputs
-            (output_address,) = addresses.outputs
             self.stats_update_address = addresses.frontend_stats_publish_address
             if coordinator is not None:
                 assert self.stats_update_address == (
                     coordinator.get_stats_publish_address()
                 )

-        # Create input and output sockets.
-        self.input_socket = self.resources.input_socket = make_zmq_socket(
-            self.ctx, input_address, zmq.ROUTER, bind=True
-        )
-        self.resources.output_socket = make_zmq_socket(
-            self.ctx, output_address, zmq.PULL
-        )
-
         parallel_config = vllm_config.parallel_config
         dp_size = parallel_config.data_parallel_size
         dp_rank = parallel_config.data_parallel_index
@@ -545,8 +609,13 @@ class MPClient(EngineCoreClient):
                 timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000  # convert to ms
             ):
                 raise TimeoutError(
-                    "Timed out waiting for engines to send "
-                    "initial message on input socket."
+                    f"Timed out waiting for engine core processes to "
+                    f"start. This is often caused by slow weight loading "
+                    f"for large models. Waited "
+                    f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by "
+                    f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the "
+                    f"timeout, set the environment variable: "
+                    f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
                 )
             identity, _ = sync_input_socket.recv_multipart()
             identities.remove(identity)
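
The rewritten message names the knob directly; for slow-loading models the timeout can be raised before the client is constructed. A minimal sketch (1200 is a hypothetical value):

import os

# Must be set before vLLM spawns the engine core processes.
os.environ["VLLM_ENGINE_READY_TIMEOUT_S"] = "1200"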
@@ -877,6 +946,10 @@ class AsyncMPClient(MPClient):
         output_socket = resources.output_socket
         assert output_socket is not None

+        notification_callback_handler: (
+            Callable[[AsyncMPClient, Sequence[Any]], Any] | None
+        ) = getattr(self.__class__, "eep_process_engine_core_notification", None)
+
         async def process_outputs_socket():
             try:
                 while True:
@@ -884,7 +957,26 @@ class AsyncMPClient(MPClient):
                     resources.validate_alive(frames)
                     outputs: EngineCoreOutputs = decoder.decode(frames)
                     if outputs.utility_output:
-                        _process_utility_output(outputs.utility_output, utility_results)
+                        if (
+                            outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID
+                            and notification_callback_handler is not None
+                        ):
+                            assert _self_ref is not None
+                            _self = _self_ref()
+                            if not _self:
+                                return
+                            if outputs.utility_output.result is None:
+                                continue
+                            notification_data = outputs.utility_output.result.result
+                            assert isinstance(notification_data, Sequence)
+                            assert len(notification_data) == 2
+                            asyncio.create_task(
+                                notification_callback_handler(_self, notification_data)
+                            )
+                        else:
+                            _process_utility_output(
+                                outputs.utility_output, utility_results
+                            )
                         continue

                 if output_handler is not None:
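
The _self_ref checks above follow the pattern already used by this coroutine: the long-lived socket task holds only a weak reference to the client, so a pending background task never keeps a dead client (and its ZMQ resources) alive. A minimal standalone sketch of that pattern, with hypothetical names:

import asyncio
import weakref


class Client:
    def start(self) -> None:
        # Hold only a weak reference; a strong reference would keep the
        # client alive for as long as the background task keeps running.
        _self_ref = weakref.ref(self)

        async def loop() -> None:
            while True:
                await asyncio.sleep(0.01)  # stand-in for a socket recv
                _self = _self_ref()
                if not _self:
                    return  # client was garbage collected; exit quietly
                _self.handle()

        self.task = asyncio.get_running_loop().create_task(loop())

    def handle(self) -> None:
        print("handled one message")


async def main() -> None:
    client = Client()
    client.start()
    await asyncio.sleep(0.05)


asyncio.run(main())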
@@ -1081,6 +1173,8 @@ class DPAsyncMPClient(AsyncMPClient):
         # Used only by DPLBAsyncMPClient subclass.
         self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines]

+        self.eep_scaling_cache: ElasticScalingCache | None = None
+
         self.first_req_sock_addr = get_open_zmq_inproc_path()
         self.first_req_send_socket = self.resources.first_req_send_socket = (
             make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True)
@@ -1101,12 +1195,6 @@ class DPAsyncMPClient(AsyncMPClient):
         assert self.stats_update_address is not None
         stats_addr: str = self.stats_update_address
         assert len(self.engine_ranks_managed) > 0
-        # NOTE: running and waiting counts are all global from
-        # the Coordinator include all global EngineCores. This
-        # slice includes just the cores managed by this client.
-        count_slice = slice(
-            self.engine_ranks_managed[0], self.engine_ranks_managed[-1] + 1
-        )

         async def run_engine_stats_update_task():
             with (
@@ -1145,6 +1233,29 @@ class DPAsyncMPClient(AsyncMPClient):
                            ):
                                # Extract new engine count from the decoded message
                                new_engine_count = decoded[1]
+                                # Update engine_ranks_managed and count_slice
+                                parallel_config = self.vllm_config.parallel_config
+                                dp_size = parallel_config.data_parallel_size
+                                dp_rank = parallel_config.data_parallel_rank
+                                assert dp_rank == 0
+                                assert dp_size == new_engine_count
+                                assert not (
+                                    parallel_config.data_parallel_hybrid_lb
+                                    or parallel_config.data_parallel_external_lb
+                                )
+                                num_ranks = dp_size
+                                self.engine_ranks_managed = list(
+                                    range(dp_rank, dp_rank + num_ranks)
+                                )
+                                if len(self.lb_engines) < new_engine_count:
+                                    self.lb_engines = self.lb_engines + [
+                                        [0, 0]
+                                        for _ in range(
+                                            new_engine_count - len(self.lb_engines)
+                                        )
+                                    ]
+                                else:
+                                    self.lb_engines = self.lb_engines[:new_engine_count]
                                # Send scale up notification to coordinator
                                scale_msg = msgspec.msgpack.encode(
                                    ("SCALE_ELASTIC_EP", new_engine_count)
@@ -1178,6 +1289,11 @@ class DPAsyncMPClient(AsyncMPClient):
                        self.current_wave = wave
                        self.engines_running = running
                        if counts is not None:
+                            # Running and waiting counts are global from the
+                            # Coordinator including all EngineCores. Slice to get
+                            # just the cores managed by this client.
+                            ranks = self.engine_ranks_managed
+                            count_slice = slice(ranks[0], ranks[-1] + 1)
                            sliced_counts = counts[count_slice]
                            self.lb_engines = sliced_counts
                            logger.debug(
@@ -1287,6 +1403,67 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
            for req_id in outputs.finished_requests:
                self.reqs_in_flight.pop(req_id, None)

+    @staticmethod
+    async def eep_process_engine_core_notification(
+        self: "DPLBAsyncMPClient", notification_data: tuple[str, int]
+    ):
+        cache = self.eep_scaling_cache
+        notification_type_str, dp_rank = notification_data
+        try:
+            notification_type = EEPNotificationType(notification_type_str)
+        except ValueError as e:
+            raise ValueError(
+                f"Unknown EEP notification type: {notification_type_str}"
+            ) from e
+
+        if notification_type == EEPNotificationType.RECONFIGURE_FINISHED:
+            from vllm.v1.engine import UtilityResult
+
+            # NOTE(yongji): process a dummy UtilityOutput to resolve the future
+            # awaited in _eep_wait_for_setup_switch_complete(), signaling that
+            # all engine cores have completed reconfiguration.
+            dummy_output = UtilityOutput(
+                call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
+            )
+            _process_utility_output(dummy_output, self.utility_results)
+            return
+        assert cache is not None
+        if notification_type not in cache.pending_notifications:
+            cache.pending_notifications[notification_type] = set()
+        if dp_rank in cache.pending_notifications[notification_type]:
+            raise ValueError(
+                f"Duplicate notification {notification_type} from dp_rank {dp_rank}"
+            )
+        cache.pending_notifications[notification_type].add(dp_rank)
+        if len(cache.pending_notifications[notification_type]) >= abs(
+            cache.num_new_core_engines
+        ):
+            if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE:
+                assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
+                assert cache.num_new_core_engines < 0
+                old_dp_size = len(cache.existing_core_engines)
+                new_dp_size = old_dp_size + cache.num_new_core_engines
+                self.resources.engine_manager.scale_down_elastic_ep(
+                    old_dp_size, new_dp_size
+                )
+            else:
+                await asyncio.gather(
+                    *[
+                        self._call_utility_async(
+                            "eep_handle_engine_core_notification",
+                            notification_type,
+                            engine=engine,
+                        )
+                        for engine in cache.existing_core_engines
+                    ]
+                )
+            cache.pending_notifications[notification_type] = set()
+            if notification_type in [
+                EEPNotificationType.SHUTDOWN_COMPLETE,
+                EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY,
+            ]:
+                self.eep_scaling_cache = None
+
     async def abort_requests_async(self, request_ids: list[str]) -> None:
         if not request_ids or self.resources.engine_dead:
             return
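
The handler above counts one notification per DP rank and only acts once abs(num_new_core_engines) distinct ranks have reported, a barrier-style rendezvous. A minimal standalone sketch of that bookkeeping, with hypothetical names:

from dataclasses import dataclass, field


@dataclass
class PendingBarrier:
    expected: int
    seen: set[int] = field(default_factory=set)

    def arrive(self, rank: int) -> bool:
        """Record one rank; return True once all expected ranks reported."""
        if rank in self.seen:
            raise ValueError(f"Duplicate notification from rank {rank}")
        self.seen.add(rank)
        return len(self.seen) >= self.expected


barrier = PendingBarrier(expected=2)
assert not barrier.arrive(0)
assert barrier.arrive(1)  # second distinct rank completes the barrier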
@@ -1333,6 +1510,20 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
             cur_data_parallel_size, new_data_parallel_size
         )

+    async def _eep_wait_for_setup_switch_complete(self) -> None:
+        """
+        Wait for core engines to switch to the new setup.
+
+        In eep_process_engine_core_notification(), a dummy UtilityOutput with
+        EEP_NOTIFICATION_CALL_ID will be set when RECONFIGURE_FINISHED
+        notification is received from engine 0. We create a future with
+        that call_id and wait for it to be resolved.
+        """
+        future = asyncio.get_running_loop().create_future()
+        self.utility_results[EEP_NOTIFICATION_CALL_ID] = future
+        self._ensure_output_queue_task()
+        await future
+
     async def _scale_up_elastic_ep(
         self, cur_data_parallel_size: int, new_data_parallel_size: int
     ) -> None:
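
The wait helper is the consumer half of a one-shot signal: a future registered under a well-known call id, resolved later by the notification handler. The same rendezvous, reduced to a runnable standalone sketch:

import asyncio

SIGNAL_ID = "reconfigure-finished"  # stands in for EEP_NOTIFICATION_CALL_ID


async def main() -> None:
    results: dict[str, asyncio.Future] = {}

    async def waiter() -> None:
        future = asyncio.get_running_loop().create_future()
        results[SIGNAL_ID] = future
        await future  # parked until the producer resolves it
        print("setup switch complete")

    async def producer() -> None:
        await asyncio.sleep(0.01)  # stand-in for the engine notification
        results.pop(SIGNAL_ID).set_result(None)

    await asyncio.gather(waiter(), producer())


asyncio.run(main())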
@@ -1340,38 +1531,57 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
         and reconfiguring existing ones."""
         cur_data_parallel_size = len(self.core_engines)

-        # Phase 1: Send reconfigure messages to all existing engines and wait
-        # for them to be sent
+        self.eep_scaling_cache = ElasticScalingCache(
+            existing_core_engines=self.core_engines.copy(),
+            num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
+            pending_notifications=dict(),
+        )
+
+        parallel_config = self.vllm_config.parallel_config
+        allocate_stateless_group_ports(parallel_config, new_data_parallel_size)
+
+        # Phase 1: Send reconfig messages to existing engines
         reconfig_futures = []
-        self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
         for engine in self.core_engines:
             reconfig_request = ReconfigureDistributedRequest(
                 new_data_parallel_size=new_data_parallel_size,
                 new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
                 new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
-                new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
-                new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
+                new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
+                new_data_parallel_master_port=parallel_config.data_parallel_master_port,
+                new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
+                new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
+                new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
+                new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
+                new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
             )
             coro = self._call_utility_async(
                 "reinitialize_distributed", reconfig_request, engine=engine
             )
             reconfig_futures.append(asyncio.create_task(coro))

-        logger.info("All reconfigure messages sent, starting engine creation")
-
-        # Phase 2: Create new engines now that reconfig messages have been sent
-        # self.resources.engine_manager is guaranteed to be
-        # CoreEngineActorManager for RayDPClient
+        # Phase 2: Create new engines
         assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
-        self.resources.engine_manager.scale_up_elastic_ep(
-            self.vllm_config, new_data_parallel_size
+        parallel_config.eplb_config.num_redundant_experts = 0
+        start_new_worker_future = asyncio.to_thread(
+            self.resources.engine_manager.scale_up_elastic_ep,
+            self.vllm_config,
+            new_data_parallel_size,
         )
+        wait_future = self._eep_wait_for_setup_switch_complete()
+
+        # Phase 3: Wait for new engines to be created
+        # and reconfig messages to be received
+        await asyncio.gather(start_new_worker_future, *reconfig_futures)
+        logger.info("[Elastic EP] Successfully started new engines")

         # Create new CoreEngine objects for the new engines
         new_engine_identities = set()
         for i in range(cur_data_parallel_size, new_data_parallel_size):
             new_engine = i.to_bytes(2, "little")
             self.core_engines.append(new_engine)
+            # NOTE(yongji): we don't update lb_engines here,
+            # we let run_engine_stats_update_task to update it.
             new_engine_identities.add(new_engine)

         # Wait for ready messages from new engines on the input socket
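
Moving scale_up_elastic_ep into asyncio.to_thread lets the blocking actor creation overlap with the in-flight reconfigure RPCs, which the gather above then joins. The shape of that overlap, reduced to a runnable sketch with hypothetical stand-ins:

import asyncio
import time


def start_new_workers() -> str:
    time.sleep(0.2)  # blocking work, e.g. creating actors / loading weights
    return "workers up"


async def send_reconfig(rank: int) -> str:
    await asyncio.sleep(0.1)  # stand-in for an async utility RPC
    return f"engine {rank} reconfigured"


async def main() -> None:
    # Kick off the blocking call on a worker thread *before* awaiting the
    # RPCs, so both proceed concurrently instead of back to back.
    worker_future = asyncio.to_thread(start_new_workers)
    rpcs = [send_reconfig(r) for r in range(2)]
    print(await asyncio.gather(worker_future, *rpcs))


asyncio.run(main())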
@@ -1381,16 +1591,21 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
                 timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000  # convert to ms
             ):
                 raise TimeoutError(
-                    "Timed out waiting for new engines to send initial "
-                    "message on input socket."
+                    f"Timed out waiting for new engine core processes to "
+                    f"start. Waited "
+                    f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by "
+                    f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the "
+                    f"timeout, set the environment variable: "
+                    f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
                 )
             identity, _ = sync_input_socket.recv_multipart()
             new_engine_identities.discard(identity)

-        # Phase 3: Wait for all existing engines to complete reconfiguration
-        logger.info("Waiting for existing engines to complete reconfiguration")
-        await asyncio.gather(*reconfig_futures)
-
+        # NOTE(yongji): Before we schedule any requests on the new workers,
+        # we should wait for them to switch to the new setup.
+        await wait_future
+        # Update the parallel config
+        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
         # Notify coordinator about scale up through existing
         # stats_update_task connection
         self._ensure_stats_update_task()
@@ -1399,8 +1614,6 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
         )
         await self.first_req_send_socket.send(scale_up_marker)

-        # Update the parallel config
-        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
         logger.info(
             "[Elastic EP] Scale up completed, new data parallel size: %s",
             new_data_parallel_size,
@@ -1413,7 +1626,14 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
         reconfiguring existing engine cores."""
         cur_data_parallel_size = len(self.core_engines)

-        self.vllm_config.parallel_config.data_parallel_master_port = get_open_port()
+        self.eep_scaling_cache = ElasticScalingCache(
+            existing_core_engines=self.core_engines.copy(),
+            num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
+            pending_notifications=dict(),
+        )
+
+        parallel_config = self.vllm_config.parallel_config
+        allocate_stateless_group_ports(parallel_config, new_data_parallel_size)

         reconfig_futures = []
         for cur_dp_rank, engine in enumerate(self.core_engines):
@@ -1421,8 +1641,13 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
                 new_data_parallel_size=new_data_parallel_size,
                 new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
                 new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
-                new_data_parallel_master_ip=self.vllm_config.parallel_config.data_parallel_master_ip,
-                new_data_parallel_master_port=self.vllm_config.parallel_config.data_parallel_master_port,
+                new_data_parallel_master_ip=parallel_config.data_parallel_master_ip,
+                new_data_parallel_master_port=parallel_config.data_parallel_master_port,
+                new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
+                new_stateless_world_group_port_list=parallel_config._stateless_world_group_port_list,
+                new_stateless_dp_group_port_list=parallel_config._stateless_dp_group_port_list,
+                new_stateless_ep_group_port_list=parallel_config._stateless_ep_group_port_list,
+                new_stateless_eplb_group_port_list=parallel_config._stateless_eplb_group_port_list,
             )
             if cur_dp_rank >= new_data_parallel_size:
                 reconfig_request.new_data_parallel_rank = (
@@ -1433,23 +1658,24 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
             )
             reconfig_futures.append(asyncio.create_task(coro))

-        for _ in range(new_data_parallel_size, cur_data_parallel_size):
-            self.core_engines.pop()
+        # NOTE(yongji): Immediately stop sending requests to the removing engines.
+        self.core_engines = self.core_engines[:new_data_parallel_size]
+        self.lb_engines = self.lb_engines[:new_data_parallel_size]
+        wait_future = self._eep_wait_for_setup_switch_complete()

         await asyncio.gather(*reconfig_futures)

-        assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
-        self.resources.engine_manager.scale_down_elastic_ep(
-            cur_data_parallel_size, new_data_parallel_size
-        )
-
-        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
         self._ensure_stats_update_task()
         scale_down_marker = msgspec.msgpack.encode(
             ("SCALE_ELASTIC_EP", new_data_parallel_size)
         )
         await self.first_req_send_socket.send(scale_down_marker)

+        self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
+        # NOTE(yongji): Unlike scaling up,
+        # here we don't actually need to wait for the setup switch to complete.
+        # We may want to remove it in the future.
+        await wait_future
         logger.info(
             "[Elastic EP] Scale down completed, new data parallel size: %s",
             new_data_parallel_size,
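
End to end, both paths are reached through the client's scaling entry point. A hypothetical driver, assuming a constructed DPLBAsyncMPClient whose scale_elastic_ep() dispatches to _scale_up_elastic_ep / _scale_down_elastic_ep based on the target size:

async def resize(client, new_data_parallel_size: int) -> None:
    cur = len(client.core_engines)
    if new_data_parallel_size == cur:
        return  # nothing to do
    # Scaling up launches new engine actors and reconfigures existing ones;
    # scaling down reconfigures first and reaps actors on SHUTDOWN_COMPLETE.
    await client.scale_elastic_ep(new_data_parallel_size)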