### What this PR does / why we need it?
| File Path |
| :--- |
| ` vllm_ascend/eplb/adaptor/abstract_adaptor.py` |
| ` vllm_ascend/eplb/adaptor/vllm_adaptor.py` |
| ` vllm_ascend/eplb/core/eplb_device_transfer_loader.py` |
| ` vllm_ascend/eplb/core/eplb_utils.py` |
| ` vllm_ascend/eplb/core/eplb_worker.py` |
| ` vllm_ascend/eplb/core/policy/policy_abstract.py` |
| ` vllm_ascend/eplb/core/policy/policy_default_eplb.py` |
| ` vllm_ascend/eplb/core/policy/policy_factory.py` |
| ` vllm_ascend/eplb/core/policy/policy_flashlb.py` |
| ` vllm_ascend/eplb/core/policy/policy_random.py` |
| ` vllm_ascend/eplb/core/policy/policy_swift_balancer.py` |
| ` vllm_ascend/eplb/eplb_updator.py` |
| ` vllm_ascend/eplb/utils.py` |
| ` vllm_ascend/model_loader/netloader/executor/elastic_load.py` |
| ` vllm_ascend/model_loader/netloader/executor/netloader_pg.py` |
| ` vllm_ascend/model_loader/netloader/interaction/elastic.py` |
| ` vllm_ascend/model_loader/netloader/load.py` |
| ` vllm_ascend/model_loader/netloader/netloader.py` |
| ` vllm_ascend/model_loader/netloader/utils.py` |
| ` vllm_ascend/patch/platform/__init__.py` |
| ` vllm_ascend/patch/platform/patch_balance_schedule.py` |
| ` vllm_ascend/patch/platform/patch_ec_connector.py` |
| ` vllm_ascend/patch/platform/patch_mamba_config.py` |
| ` vllm_ascend/patch/platform/patch_multiproc_executor.py` |
| ` vllm_ascend/patch/platform/patch_sched_yield.py` |
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -18,8 +18,7 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.logger import logger
|
||||
|
||||
from .netloader_pg import (destroy_stateless_process_group,
|
||||
stateless_init_process_group)
|
||||
from .netloader_pg import destroy_stateless_process_group, stateless_init_process_group
|
||||
|
||||
|
||||
class P2PLoad:
|
||||
@@ -56,9 +55,7 @@ class P2PLoad:
|
||||
- The model if loading is successful, otherwise None.
|
||||
"""
|
||||
model_device = next(model.parameters()).device
|
||||
logger.info(
|
||||
f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
)
|
||||
logger.info(f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||
receiver_pg = None
|
||||
loaded_model = None
|
||||
try:
|
||||
@@ -67,15 +64,13 @@ class P2PLoad:
|
||||
port=self.source_port,
|
||||
rank=0,
|
||||
world_size=2,
|
||||
group_name='netloader',
|
||||
group_name="netloader",
|
||||
)
|
||||
logger.info(
|
||||
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
)
|
||||
logger.info(f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||
logger.info(f"Model device: {model_device}")
|
||||
|
||||
trans_stream = torch_npu.npu.Stream()
|
||||
@@ -84,14 +79,11 @@ class P2PLoad:
|
||||
if len(param.shape) == 0:
|
||||
continue
|
||||
receiver_pg.recv([param], 1, 0).wait()
|
||||
torch.distributed.barrier(group=receiver_pg,
|
||||
device_ids=[model_device.index])
|
||||
torch.distributed.barrier(group=receiver_pg, device_ids=[model_device.index])
|
||||
|
||||
torch_npu.npu.synchronize(trans_stream)
|
||||
|
||||
logger.info(
|
||||
f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
|
||||
)
|
||||
logger.info(f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}")
|
||||
loaded_model = model
|
||||
except Exception as e:
|
||||
logger.error("Failed to recv model: {}".format(e))
|
||||
@@ -129,9 +121,7 @@ class P2PSend:
|
||||
"""
|
||||
model_device = next(model.parameters()).device
|
||||
torch.npu.set_device(model_device)
|
||||
logger.info(
|
||||
f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
)
|
||||
logger.info(f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||
sender_pg = None
|
||||
try:
|
||||
sender_pg = stateless_init_process_group(
|
||||
@@ -139,14 +129,10 @@ class P2PSend:
|
||||
port=self.listen_port,
|
||||
rank=1,
|
||||
world_size=2,
|
||||
group_name='netloader',
|
||||
)
|
||||
logger.info(
|
||||
f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
)
|
||||
logger.info(
|
||||
f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
group_name="netloader",
|
||||
)
|
||||
logger.info(f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||
logger.info(f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||
logger.info(f"Model device: {model_device}")
|
||||
|
||||
trans_stream = torch_npu.npu.Stream()
|
||||
@@ -155,16 +141,12 @@ class P2PSend:
|
||||
if "aclnn_input_scale" in name:
|
||||
continue
|
||||
if name in int8_params:
|
||||
sender_pg.send([int8_params[name].to(model_device)], 0,
|
||||
0).wait()
|
||||
sender_pg.send([int8_params[name].to(model_device)], 0, 0).wait()
|
||||
else:
|
||||
sender_pg.send([param.contiguous()], 0, 0).wait()
|
||||
torch.distributed.barrier(group=sender_pg,
|
||||
device_ids=[model_device.index])
|
||||
torch.distributed.barrier(group=sender_pg, device_ids=[model_device.index])
|
||||
torch_npu.npu.synchronize(trans_stream)
|
||||
logger.info(
|
||||
f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
|
||||
)
|
||||
logger.info(f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}")
|
||||
finally:
|
||||
if sender_pg:
|
||||
destroy_stateless_process_group(sender_pg)
|
||||
destroy_stateless_process_group(sender_pg)
|
||||
|
||||
@@ -17,16 +17,13 @@
|
||||
import gc
|
||||
import ipaddress
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch._C._distributed_c10d import (_DEFAULT_PG_TIMEOUT,
|
||||
_register_process_group,
|
||||
_unregister_process_group)
|
||||
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT, _register_process_group, _unregister_process_group
|
||||
from torch.distributed import ProcessGroup, is_hccl_available
|
||||
from torch.distributed.distributed_c10d import (Backend, BackendConfig,
|
||||
PrefixStore, _world)
|
||||
from torch.distributed.distributed_c10d import Backend, BackendConfig, PrefixStore, _world
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
from torch_npu._C._distributed_c10d import ProcessGroupHCCL
|
||||
from vllm.logger import logger
|
||||
@@ -39,7 +36,7 @@ def stateless_init_process_group(
|
||||
rank: int,
|
||||
timeout: timedelta = _DEFAULT_PG_TIMEOUT,
|
||||
group_name: str = "",
|
||||
pg_options: Optional[Any] = None,
|
||||
pg_options: Any | None = None,
|
||||
) -> ProcessGroup:
|
||||
"""
|
||||
Initializes a stateless process group.
|
||||
@@ -57,7 +54,8 @@ def stateless_init_process_group(
|
||||
ProcessGroup: The initialized process group.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If world_size is not positive, or if rank is not within [0, world_size - 1], or if HCCL is unavailable.
|
||||
RuntimeError: If world_size is not positive, or if rank is not within
|
||||
[0, world_size - 1], or if HCCL is unavailable.
|
||||
TypeError: If timeout is not a timedelta type.
|
||||
ValueError: If group_name already exists.
|
||||
"""
|
||||
@@ -67,21 +65,18 @@ def stateless_init_process_group(
|
||||
raise RuntimeError("world_size must be positive")
|
||||
# Check if rank is within [0, world_size - 1]
|
||||
if not (rank >= 0 and rank <= world_size - 1):
|
||||
raise RuntimeError(
|
||||
"rank should be a number between 0 and ``world_size``-1")
|
||||
raise RuntimeError("rank should be a number between 0 and ``world_size``-1")
|
||||
# Check if HCCL is available
|
||||
if not is_hccl_available():
|
||||
raise RuntimeError("HCCL is not available")
|
||||
# Check if timeout is a timedelta type
|
||||
if not isinstance(timeout, timedelta):
|
||||
raise TypeError(
|
||||
f"Expected timeout argument to be of type datetime.timedelta, got {timeout}"
|
||||
)
|
||||
raise TypeError(f"Expected timeout argument to be of type datetime.timedelta, got {timeout}")
|
||||
# Check if group_name already exists
|
||||
if group_name in _world.pg_names.values():
|
||||
raise ValueError(
|
||||
f"The specified group name {group_name} has already been "
|
||||
"created, please use a different group name")
|
||||
f"The specified group name {group_name} has already been created, please use a different group name"
|
||||
)
|
||||
|
||||
# Function to check if an IPv6 address is valid
|
||||
def is_valid_ipv6_address(address: str) -> bool:
|
||||
@@ -101,10 +96,9 @@ def stateless_init_process_group(
|
||||
# Get initialization method
|
||||
init_method = get_tcp_uri(host, port)
|
||||
# Create Backend object
|
||||
backend = Backend('hccl')
|
||||
backend = Backend("hccl")
|
||||
# Use rendezvous function to get store, rank, and world_size
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
store, rank, world_size = next(rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
|
||||
# Set timeout for store
|
||||
store.set_timeout(timeout)
|
||||
@@ -125,9 +119,7 @@ def stateless_init_process_group(
|
||||
pg._set_default_backend(Backend.backend_type_map[backend])
|
||||
|
||||
# Check if pg_options is None or not of type ProcessGroupHCCL.Options
|
||||
if pg_options is None or not isinstance(
|
||||
pg_options,
|
||||
torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
|
||||
if pg_options is None or not isinstance(pg_options, torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options):
|
||||
pg_options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
|
||||
# Set attributes for pg_options
|
||||
pg_options.is_high_priority_stream = False
|
||||
@@ -135,8 +127,7 @@ def stateless_init_process_group(
|
||||
pg_options.global_ranks_in_group = []
|
||||
pg_options.group_id = f"{init_method}/{group_name}/"
|
||||
# Create ProcessGroupHCCL object
|
||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
|
||||
pg_options)
|
||||
backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, pg_options)
|
||||
# Set sequence number for backend_class
|
||||
backend_class._set_sequence_number_for_group()
|
||||
# Set backend_type
|
||||
@@ -176,9 +167,10 @@ def destroy_stateless_process_group(pg: ProcessGroup, manual_gc: bool = False):
|
||||
_world.pg_group_ranks.pop(pg, None)
|
||||
_world.pg_backend_config.pop(pg, None)
|
||||
# Check if pg is in keys of _world.pg_coalesce_state
|
||||
if pg in _world.pg_coalesce_state.keys():
|
||||
logger.warning("Some coalesced collectives haven't been launched when "
|
||||
"ProcessGroup is destroyed. They will be cleaned.")
|
||||
if pg in _world.pg_coalesce_state:
|
||||
logger.warning(
|
||||
"Some coalesced collectives haven't been launched when ProcessGroup is destroyed. They will be cleaned."
|
||||
)
|
||||
del _world.pg_coalesce_state[pg]
|
||||
# Unregister the process group
|
||||
_unregister_process_group(pg.group_name)
|
||||
|
||||
@@ -18,7 +18,7 @@ import json
|
||||
import re
|
||||
import socket
|
||||
import threading
|
||||
from typing import List, Optional, Tuple
|
||||
from contextlib import suppress
|
||||
|
||||
import torch
|
||||
from vllm.logger import logger
|
||||
@@ -32,8 +32,7 @@ class ElasticClient:
|
||||
Class for handling the client-side logic of Netloader of models.
|
||||
"""
|
||||
|
||||
def __init__(self, sources: list[str], device_id: int, model_path: str,
|
||||
tp: int, pp: int):
|
||||
def __init__(self, sources: list[str], device_id: int, model_path: str, tp: int, pp: int):
|
||||
"""
|
||||
Initializes the ElasticClient instance.
|
||||
|
||||
@@ -50,14 +49,14 @@ class ElasticClient:
|
||||
self.tp = tp
|
||||
self.pp = pp
|
||||
|
||||
self.s: Optional[socket.socket] = None
|
||||
self.ack: Optional[Tuple[str, int]] = None
|
||||
self.server_addr: Optional[str] = None
|
||||
self.server_port: Optional[int] = None
|
||||
self.s: socket.socket | None = None
|
||||
self.ack: tuple[str, int] | None = None
|
||||
self.server_addr: str | None = None
|
||||
self.server_port: int | None = None
|
||||
|
||||
for source in self.sources:
|
||||
try:
|
||||
ip, port_str = source.split(':')
|
||||
ip, port_str = source.split(":")
|
||||
port = int(port_str)
|
||||
except Exception as e:
|
||||
logger.info(f"IP format error: {source}, detail: {e}")
|
||||
@@ -68,13 +67,9 @@ class ElasticClient:
|
||||
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
logger.info(
|
||||
f"Start connection to server: {self.server_addr}:{self.server_port}"
|
||||
)
|
||||
logger.info(f"Start connection to server: {self.server_addr}:{self.server_port}")
|
||||
sock.connect((self.server_addr, self.server_port))
|
||||
logger.info(
|
||||
f"Finish connection to server: {self.server_addr}:{self.server_port}"
|
||||
)
|
||||
logger.info(f"Finish connection to server: {self.server_addr}:{self.server_port}")
|
||||
sock.settimeout(60)
|
||||
|
||||
self.s = sock
|
||||
@@ -83,10 +78,8 @@ class ElasticClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Connect to {source} fails, detail: {e}")
|
||||
if sock is not None:
|
||||
try:
|
||||
with suppress(Exception):
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.s = None
|
||||
self.ack = None
|
||||
self.server_addr = None
|
||||
@@ -120,10 +113,8 @@ class ElasticClient:
|
||||
"""
|
||||
Destructor method to ensure socket is closed.
|
||||
"""
|
||||
try:
|
||||
with suppress(Exception):
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def send_str(self, data_str: str) -> None:
|
||||
"""
|
||||
@@ -151,8 +142,7 @@ class ElasticClient:
|
||||
data_str = self.s.recv(buffer_size).decode("utf-8")
|
||||
return data_str
|
||||
|
||||
def register(self, device_id: int, model_path: str, tp: int,
|
||||
pp: int) -> Tuple[str, int]:
|
||||
def register(self, device_id: int, model_path: str, tp: int, pp: int) -> tuple[str, int]:
|
||||
"""
|
||||
Registers the client with the server.
|
||||
|
||||
@@ -168,20 +158,13 @@ class ElasticClient:
|
||||
free_port = find_free_port()
|
||||
data = {
|
||||
"label": "JOIN",
|
||||
"content": {
|
||||
'device_id': device_id,
|
||||
'model_path': model_path,
|
||||
'tp': tp,
|
||||
'pp': pp,
|
||||
'port': free_port
|
||||
}
|
||||
"content": {"device_id": device_id, "model_path": model_path, "tp": tp, "pp": pp, "port": free_port},
|
||||
}
|
||||
|
||||
try:
|
||||
self.send_str(json.dumps(data))
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Send data {data} to server fails, detail: {e}")
|
||||
raise RuntimeError(f"Send data {data} to server fails, detail: {e}")
|
||||
|
||||
try:
|
||||
ack_str = self.recv_str()
|
||||
@@ -191,23 +174,22 @@ class ElasticClient:
|
||||
try:
|
||||
ack = json.loads(ack_str)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}"
|
||||
)
|
||||
raise RuntimeError(f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}")
|
||||
|
||||
logger.info(f"Receive ack: {ack}")
|
||||
|
||||
if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack
|
||||
and ack["content"] is not None and "name" in ack["content"]):
|
||||
if (
|
||||
"label" in ack
|
||||
and ack["label"] == "JOIN_ACK"
|
||||
and "content" in ack
|
||||
and ack["content"] is not None
|
||||
and "name" in ack["content"]
|
||||
):
|
||||
return (ack["content"]["name"], free_port)
|
||||
elif ("label" in ack and ack["label"] == 'JOIN_NACK'
|
||||
and "content" in ack):
|
||||
raise RuntimeError(
|
||||
f"Receive nack from server, reason: {ack['content']}")
|
||||
elif "label" in ack and ack["label"] == "JOIN_NACK" and "content" in ack:
|
||||
raise RuntimeError(f"Receive nack from server, reason: {ack['content']}")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Receive ack {ack} from server does not contain required fields"
|
||||
)
|
||||
raise RuntimeError(f"Receive ack {ack} from server does not contain required fields")
|
||||
|
||||
|
||||
class ElasticServer:
|
||||
@@ -215,9 +197,18 @@ class ElasticServer:
|
||||
Class for handling the server-side logic of Netloader of models.
|
||||
"""
|
||||
|
||||
def __init__(self, addr: str, port: int, model, device_id: int,
|
||||
model_path: str, tp: int, pp: int, int8_cache: str,
|
||||
int8_cache_name: Optional[List[str]]):
|
||||
def __init__(
|
||||
self,
|
||||
addr: str,
|
||||
port: int,
|
||||
model,
|
||||
device_id: int,
|
||||
model_path: str,
|
||||
tp: int,
|
||||
pp: int,
|
||||
int8_cache: str,
|
||||
int8_cache_name: list[str] | None,
|
||||
):
|
||||
"""
|
||||
Initializes the ElasticServer instance.
|
||||
|
||||
@@ -246,30 +237,25 @@ class ElasticServer:
|
||||
self.pp = pp
|
||||
|
||||
self.original_int8 = {}
|
||||
int8_pattern = "|".join(
|
||||
map(re.escape,
|
||||
int8_cache_name)) if int8_cache_name is not None else "(?:)"
|
||||
int8_pattern = "|".join(map(re.escape, int8_cache_name)) if int8_cache_name is not None else "(?:)"
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.dtype == torch.int8:
|
||||
if int8_cache == 'hbm':
|
||||
if int8_cache == "hbm":
|
||||
if int8_cache_name is None or (
|
||||
int8_cache_name is not None
|
||||
and re.search(int8_pattern, name) is not None):
|
||||
int8_cache_name is not None and re.search(int8_pattern, name) is not None
|
||||
):
|
||||
try:
|
||||
self.original_int8[name] = param.data.clone(
|
||||
).detach()
|
||||
self.original_int8[name] = param.data.clone().detach()
|
||||
except RuntimeError as e:
|
||||
logger.error(
|
||||
f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}"
|
||||
)
|
||||
logger.error(f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}")
|
||||
self.original_int8[name] = param.data.cpu()
|
||||
|
||||
elif int8_cache == 'dram':
|
||||
elif int8_cache == "dram":
|
||||
if int8_cache_name is None or (
|
||||
int8_cache_name is not None
|
||||
and re.search(int8_pattern, name) is not None):
|
||||
int8_cache_name is not None and re.search(int8_pattern, name) is not None
|
||||
):
|
||||
self.original_int8[name] = param.data.cpu()
|
||||
elif int8_cache == 'no':
|
||||
elif int8_cache == "no":
|
||||
pass
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -277,14 +263,18 @@ class ElasticServer:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}"
|
||||
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, "
|
||||
f"model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, "
|
||||
f"int8 params {list(self.original_int8)} are saved to {int8_cache}"
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor method to ensure socket is closed.
|
||||
"""
|
||||
self.s.close()
|
||||
if self.s is not None:
|
||||
with suppress(Exception):
|
||||
self.s.close()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@@ -343,10 +333,7 @@ class ElasticServer:
|
||||
if not all(k in content for k in required_keys):
|
||||
return False
|
||||
port = content["port"]
|
||||
if not (isinstance(port, int) or
|
||||
(isinstance(port, str) and port.isdigit())):
|
||||
return False
|
||||
return True
|
||||
return isinstance(port, int) or (isinstance(port, str) and port.isdigit())
|
||||
|
||||
comm_name = None
|
||||
if is_valid_data(data):
|
||||
@@ -355,36 +342,31 @@ class ElasticServer:
|
||||
tp = int(data["content"]["tp"])
|
||||
pp = int(data["content"]["pp"])
|
||||
|
||||
if int(self.device_id
|
||||
) == device_id and self.model_path == model_path and int(
|
||||
self.tp) == tp and int(self.pp) == pp:
|
||||
if (
|
||||
int(self.device_id) == device_id
|
||||
and self.model_path == model_path
|
||||
and int(self.tp) == tp
|
||||
and int(self.pp) == pp
|
||||
):
|
||||
comm_name = str(addr[0]) + ":" + str(addr[1])
|
||||
ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
|
||||
)
|
||||
server_desc = (int(self.device_id), self.model_path, int(self.tp), int(self.pp))
|
||||
client_desc = (device_id, model_path, tp, pp)
|
||||
msg = f"Received data {client_desc} does not consist with this server {server_desc}"
|
||||
logger.warning(msg)
|
||||
ack = {
|
||||
"label":
|
||||
"JOIN_NACK",
|
||||
"content":
|
||||
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
|
||||
"label": "JOIN_NACK",
|
||||
"content": msg,
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Received data does not contain required fields: {data}")
|
||||
ack = {
|
||||
"label":
|
||||
"JOIN_NACK",
|
||||
"content":
|
||||
f"Received data does not contain required fields: {data}"
|
||||
}
|
||||
logger.warning(f"Received data does not contain required fields: {data}")
|
||||
ack = {"label": "JOIN_NACK", "content": f"Received data does not contain required fields: {data}"}
|
||||
|
||||
try:
|
||||
ack_str = json.dumps(ack).encode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to convert {ack} to JSON format, details: {e}")
|
||||
logger.error(f"Failed to convert {ack} to JSON format, details: {e}")
|
||||
conn.close()
|
||||
return
|
||||
|
||||
@@ -395,14 +377,10 @@ class ElasticServer:
|
||||
conn.close()
|
||||
return
|
||||
|
||||
if ack["content"] and isinstance(ack["content"],
|
||||
dict) and 'name' in ack["content"]:
|
||||
if ack["content"] and isinstance(ack["content"], dict) and "name" in ack["content"]:
|
||||
try:
|
||||
p2psend = P2PSend(self.addr, data["content"]["port"],
|
||||
ack["content"]["name"])
|
||||
p2psend = P2PSend(self.addr, data["content"]["port"], ack["content"]["name"])
|
||||
p2psend.send(self.model, self.original_int8)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"P2PSend Failed to send model to {self.addr}, details: {e}"
|
||||
)
|
||||
logger.error(f"P2PSend Failed to send model to {self.addr}, details: {e}")
|
||||
conn.close()
|
||||
|
||||
@@ -48,36 +48,27 @@ def elastic_load(
|
||||
# Filter sources for the current device
|
||||
sources_this_device = []
|
||||
for s in sources:
|
||||
if isinstance(
|
||||
s, dict
|
||||
) and "device_id" in s and s["device_id"] == device_id and isinstance(
|
||||
s["sources"], list):
|
||||
if isinstance(s, dict) and "device_id" in s and s["device_id"] == device_id and isinstance(s["sources"], list):
|
||||
sources_this_device += s["sources"]
|
||||
if len(sources_this_device) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Initialize the interaction layer with the ElasticClient
|
||||
with ElasticClient(sources_this_device, device_id, model_path, tp,
|
||||
pp) as client_interaction_layer:
|
||||
with ElasticClient(sources_this_device, device_id, model_path, tp, pp) as client_interaction_layer:
|
||||
if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
|
||||
raise RuntimeError(
|
||||
"Failed to initialize ElasticClient: socket or server_addr is None"
|
||||
)
|
||||
raise RuntimeError("Failed to initialize ElasticClient: socket or server_addr is None")
|
||||
ack = client_interaction_layer.ack
|
||||
if ack is None:
|
||||
raise RuntimeError("ElasticClient.register did not return ack")
|
||||
|
||||
t0 = time.perf_counter()
|
||||
elastic_loader = P2PLoad(ack[0],
|
||||
client_interaction_layer.server_addr,
|
||||
ack[1])
|
||||
elastic_loader = P2PLoad(ack[0], client_interaction_layer.server_addr, ack[1])
|
||||
model_loaded = elastic_loader.load(model=model)
|
||||
if model_loaded is None:
|
||||
logger.error("Failed to load model")
|
||||
return None
|
||||
logger.info("Finish elastic load (duration: {}s)".format(
|
||||
time.perf_counter() - t0))
|
||||
logger.info("Finish elastic load (duration: {}s)".format(time.perf_counter() - t0))
|
||||
return model_loaded
|
||||
except Exception as e:
|
||||
logger.info(f"elastic_load error: {e}")
|
||||
|
||||
@@ -18,7 +18,6 @@ import gc
|
||||
import json
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -27,8 +26,7 @@ from vllm.logger import logger
|
||||
from vllm.model_executor.model_loader import register_model_loader
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading)
|
||||
from vllm.model_executor.model_loader.utils import initialize_model, process_weights_after_loading
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
from .interaction.elastic import ElasticServer
|
||||
@@ -41,12 +39,13 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
"""
|
||||
A model loader that uses elastic loading for loading weights.
|
||||
"""
|
||||
source: Optional[List[dict]]
|
||||
model_path: Optional[str]
|
||||
listen_port: Optional[int]
|
||||
|
||||
source: list[dict] | None
|
||||
model_path: str | None
|
||||
listen_port: int | None
|
||||
int8_cache: str
|
||||
int8_cache_name: Optional[List[str]]
|
||||
output_prefix: Optional[str]
|
||||
int8_cache_name: list[str] | None
|
||||
output_prefix: str | None
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
"""
|
||||
@@ -63,18 +62,15 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
extra = load_config.model_loader_extra_config
|
||||
if extra and "CONFIG_FILE" in extra:
|
||||
try:
|
||||
logger.info(
|
||||
f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ..."
|
||||
)
|
||||
with open(extra["CONFIG_FILE"], 'r') as f:
|
||||
logger.info(f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ...")
|
||||
with open(extra["CONFIG_FILE"]) as f:
|
||||
config = json.load(f)
|
||||
except FileNotFoundError:
|
||||
logger.error("CONFIG_FILE not found")
|
||||
except json.JSONDecodeError:
|
||||
logger.error("CONFIG_FILE is not a valid JSON file")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error while reading CONFIG_FILE: {e}")
|
||||
logger.error(f"Unexpected error while reading CONFIG_FILE: {e}")
|
||||
|
||||
if config is None and extra:
|
||||
logger.info("Reading configs in model_loader_extra_config ...")
|
||||
@@ -82,19 +78,30 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
config = config or {}
|
||||
|
||||
for key, attr, checker, caster, default in [
|
||||
("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v,
|
||||
None),
|
||||
("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v,
|
||||
None),
|
||||
("LISTEN_PORT", "listen_port", lambda v: isinstance(v, int) or
|
||||
(isinstance(v, str) and v.isdigit()), lambda v: int(v), None),
|
||||
("INT8_CACHE", "int8_cache", lambda v: isinstance(v, str) and v.
|
||||
lower() in ['hbm', 'dram', 'no'], lambda v: v.lower(), 'no'),
|
||||
("INT8_CACHE_NAME", "int8_cache_name",
|
||||
lambda v: isinstance(v, list), lambda v: v, None),
|
||||
("OUTPUT_PREFIX", "output_prefix",
|
||||
lambda v: isinstance(v, str) and is_valid_path_prefix(v),
|
||||
lambda v: v, None),
|
||||
("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v, None),
|
||||
("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v, None),
|
||||
(
|
||||
"LISTEN_PORT",
|
||||
"listen_port",
|
||||
lambda v: isinstance(v, int) or (isinstance(v, str) and v.isdigit()),
|
||||
lambda v: int(v),
|
||||
None,
|
||||
),
|
||||
(
|
||||
"INT8_CACHE",
|
||||
"int8_cache",
|
||||
lambda v: isinstance(v, str) and v.lower() in ["hbm", "dram", "no"],
|
||||
lambda v: v.lower(),
|
||||
"no",
|
||||
),
|
||||
("INT8_CACHE_NAME", "int8_cache_name", lambda v: isinstance(v, list), lambda v: v, None),
|
||||
(
|
||||
"OUTPUT_PREFIX",
|
||||
"output_prefix",
|
||||
lambda v: isinstance(v, str) and is_valid_path_prefix(v),
|
||||
lambda v: v,
|
||||
None,
|
||||
),
|
||||
]:
|
||||
v = config.get(key, default)
|
||||
if not checker(v):
|
||||
@@ -116,8 +123,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
self.output_prefix,
|
||||
)
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
def load_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> nn.Module:
|
||||
"""
|
||||
Loads the model using the specified configuration.
|
||||
|
||||
@@ -140,15 +146,18 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
|
||||
device_id = torch.distributed.get_rank()
|
||||
|
||||
if (self.source is None or not isinstance(self.source, list)
|
||||
or device_id not in [
|
||||
one_device["device_id"] for one_device in self.source if
|
||||
isinstance(one_device, dict) and "device_id" in one_device
|
||||
]):
|
||||
logger.warning(
|
||||
"Did not get valid source info, use DefaultModelLoader")
|
||||
model, need_process_weights_after_loading = self.revert_to_default(
|
||||
model_config, vllm_config, device_config)
|
||||
if (
|
||||
self.source is None
|
||||
or not isinstance(self.source, list)
|
||||
or device_id
|
||||
not in [
|
||||
one_device["device_id"]
|
||||
for one_device in self.source
|
||||
if isinstance(one_device, dict) and "device_id" in one_device
|
||||
]
|
||||
):
|
||||
logger.warning("Did not get valid source info, use DefaultModelLoader")
|
||||
model, need_process_weights_after_loading = self.revert_to_default(model_config, vllm_config, device_config)
|
||||
|
||||
else:
|
||||
target_device = torch.device(device_config.device)
|
||||
@@ -158,8 +167,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config,
|
||||
model_config=model_config)
|
||||
model = initialize_model(vllm_config=vllm_config, model_config=model_config)
|
||||
|
||||
start_elastic_load = time.perf_counter()
|
||||
model = elastic_load(
|
||||
@@ -171,43 +179,39 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
pp=parallel_config.pipeline_parallel_size,
|
||||
)
|
||||
end_elastic_load = time.perf_counter()
|
||||
logger.info(
|
||||
f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}"
|
||||
)
|
||||
logger.info(f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}")
|
||||
need_process_weights_after_loading = True
|
||||
|
||||
if model is None:
|
||||
logger.warning(
|
||||
"Netloader elastic loading fails, use load format DefaultModelLoader"
|
||||
)
|
||||
logger.warning("Netloader elastic loading fails, use load format DefaultModelLoader")
|
||||
|
||||
vllm_config = vllm_config_backup
|
||||
model_config = model_config_backup
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
if device_config.device_type == 'npu':
|
||||
if device_config.device_type == "npu":
|
||||
logger.info("Empty NPU cache")
|
||||
torch.npu.empty_cache()
|
||||
elif device_config.device_type == 'cuda':
|
||||
elif device_config.device_type == "cuda":
|
||||
logger.info("Empty CUDA cache")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model, need_process_weights_after_loading = self.revert_to_default(
|
||||
model_config, vllm_config, device_config)
|
||||
model_config, vllm_config, device_config
|
||||
)
|
||||
|
||||
start_elastic_server = time.perf_counter()
|
||||
# start elastic server
|
||||
if model is not None and (
|
||||
(self.listen_port and self.listen_port in range(1024, 65535)) or
|
||||
(self.listen_port is None)):
|
||||
|
||||
(self.listen_port and self.listen_port in range(1024, 65535)) or (self.listen_port is None)
|
||||
):
|
||||
from vllm.utils.network_utils import get_ip
|
||||
|
||||
driver_ip = get_ip()
|
||||
|
||||
if driver_ip == '0.0.0.0':
|
||||
logger.error(
|
||||
"Driver IP is not set, skip to start Netloader server")
|
||||
if driver_ip == "0.0.0.0":
|
||||
logger.error("Driver IP is not set, skip to start Netloader server")
|
||||
else:
|
||||
if self.listen_port is None:
|
||||
self.listen_port = find_free_port()
|
||||
@@ -220,21 +224,14 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
|
||||
if self.output_prefix is not None:
|
||||
try:
|
||||
with open(self.output_prefix + str(device_id) + '.txt',
|
||||
'w') as file:
|
||||
with open(self.output_prefix + str(device_id) + ".txt", "w") as file:
|
||||
file.write(f"{driver_ip}:{self.listen_port}")
|
||||
logger.info(
|
||||
f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}"
|
||||
)
|
||||
logger.info(f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}")
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
f"File path {self.output_prefix + str(device_id)} does not exist."
|
||||
)
|
||||
logger.error(f"File path {self.output_prefix + str(device_id)} does not exist.")
|
||||
except PermissionError:
|
||||
logger.error(
|
||||
f"No permission to write to file {self.output_prefix + str(device_id)}."
|
||||
)
|
||||
except IOError as e:
|
||||
logger.error(f"No permission to write to file {self.output_prefix + str(device_id)}.")
|
||||
except OSError as e:
|
||||
logger.error(
|
||||
f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}"
|
||||
)
|
||||
@@ -242,31 +239,30 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
logger.error(f"Unknown error: {e}")
|
||||
|
||||
try:
|
||||
assert isinstance(
|
||||
self.listen_port, int
|
||||
), f"listen port should be int but get {self.listen_port}"
|
||||
assert isinstance(self.listen_port, int), f"listen port should be int but get {self.listen_port}"
|
||||
|
||||
elastic_server = ElasticServer(
|
||||
driver_ip, self.listen_port, model, device_id,
|
||||
self.model_path, parallel_config.tensor_parallel_size,
|
||||
driver_ip,
|
||||
self.listen_port,
|
||||
model,
|
||||
device_id,
|
||||
self.model_path,
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size,
|
||||
self.int8_cache, self.int8_cache_name)
|
||||
self.int8_cache,
|
||||
self.int8_cache_name,
|
||||
)
|
||||
elastic_server.start()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to start Netloader server for rank: {device_id}, details: {e}"
|
||||
)
|
||||
logger.error(f"Failed to start Netloader server for rank: {device_id}, details: {e}")
|
||||
else:
|
||||
logger.info("Skip to start Netloader server")
|
||||
|
||||
end_elastic_server = time.perf_counter()
|
||||
logger.info(
|
||||
f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}"
|
||||
)
|
||||
logger.info(f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}")
|
||||
|
||||
if need_process_weights_after_loading:
|
||||
process_weights_after_loading(model, model_config,
|
||||
torch.device(device_config.device))
|
||||
process_weights_after_loading(model, model_config, torch.device(device_config.device))
|
||||
|
||||
if model is None:
|
||||
logger.error("NetLoader elastic loads model fails")
|
||||
@@ -274,8 +270,7 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
|
||||
return model.eval()
|
||||
|
||||
def revert_to_default(self, model_config, vllm_config,
|
||||
device_config) -> Tuple[nn.Module, bool]:
|
||||
def revert_to_default(self, model_config, vllm_config, device_config) -> tuple[nn.Module, bool]:
|
||||
"""
|
||||
Reverts to the default model loading logic when elastic loading fails or is not applicable.
|
||||
|
||||
@@ -300,19 +295,15 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
default_model_loader = DefaultModelLoader(self.load_config)
|
||||
|
||||
if model_config.quantization is None:
|
||||
model = default_model_loader.load_model(vllm_config=vllm_config,
|
||||
model_config=model_config)
|
||||
model = default_model_loader.load_model(vllm_config=vllm_config, model_config=model_config)
|
||||
need_process_weights_after_loading = False
|
||||
else:
|
||||
logger.warning(
|
||||
"Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading "
|
||||
)
|
||||
logger.warning("Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading ")
|
||||
need_process_weights_after_loading = True
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config,
|
||||
model_config=model_config)
|
||||
model = initialize_model(vllm_config=vllm_config, model_config=model_config)
|
||||
default_model_loader.load_weights(model, model_config)
|
||||
model = model.eval()
|
||||
|
||||
@@ -321,6 +312,5 @@ class ModelNetLoaderElastic(BaseModelLoader):
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
pass
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
pass
|
||||
|
||||
@@ -29,7 +29,7 @@ def find_free_port():
|
||||
- A free port number.
|
||||
"""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
@@ -47,20 +47,17 @@ def is_valid_path_prefix(path_prefix):
|
||||
return False
|
||||
|
||||
if re.search(r'[<>:"|?*]', path_prefix):
|
||||
logger.warning(
|
||||
f'The path prefix {path_prefix} contains illegal characters.')
|
||||
logger.warning(f"The path prefix {path_prefix} contains illegal characters.")
|
||||
return False
|
||||
|
||||
if path_prefix.startswith('/') or path_prefix.startswith('\\'):
|
||||
if path_prefix.startswith("/") or path_prefix.startswith("\\"):
|
||||
if not os.path.exists(os.path.dirname(path_prefix)):
|
||||
logger.warning(
|
||||
f'The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.'
|
||||
)
|
||||
logger.warning(f"The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.")
|
||||
return False
|
||||
else:
|
||||
if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))):
|
||||
logger.warning(
|
||||
f'The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist.'
|
||||
f"The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user