### What this PR does / why we need it?
| File Path |
| :--- |
| `vllm_ascend/eplb/adaptor/abstract_adaptor.py` |
| `vllm_ascend/eplb/adaptor/vllm_adaptor.py` |
| `vllm_ascend/eplb/core/eplb_device_transfer_loader.py` |
| `vllm_ascend/eplb/core/eplb_utils.py` |
| `vllm_ascend/eplb/core/eplb_worker.py` |
| `vllm_ascend/eplb/core/policy/policy_abstract.py` |
| `vllm_ascend/eplb/core/policy/policy_default_eplb.py` |
| `vllm_ascend/eplb/core/policy/policy_factory.py` |
| `vllm_ascend/eplb/core/policy/policy_flashlb.py` |
| `vllm_ascend/eplb/core/policy/policy_random.py` |
| `vllm_ascend/eplb/core/policy/policy_swift_balancer.py` |
| `vllm_ascend/eplb/eplb_updator.py` |
| `vllm_ascend/eplb/utils.py` |
| `vllm_ascend/model_loader/netloader/executor/elastic_load.py` |
| `vllm_ascend/model_loader/netloader/executor/netloader_pg.py` |
| `vllm_ascend/model_loader/netloader/interaction/elastic.py` |
| `vllm_ascend/model_loader/netloader/load.py` |
| `vllm_ascend/model_loader/netloader/netloader.py` |
| `vllm_ascend/model_loader/netloader/utils.py` |
| `vllm_ascend/patch/platform/__init__.py` |
| `vllm_ascend/patch/platform/patch_balance_schedule.py` |
| `vllm_ascend/patch/platform/patch_ec_connector.py` |
| `vllm_ascend/patch/platform/patch_mamba_config.py` |
| `vllm_ascend/patch/platform/patch_multiproc_executor.py` |
| `vllm_ascend/patch/platform/patch_sched_yield.py` |
- vLLM version: v0.13.0
- vLLM main: 2c24bc6996
---------
Signed-off-by: MrZ20 <2609716663@qq.com>
This commit is contained in:
@@ -18,7 +18,7 @@ import json
|
||||
import re
|
||||
import socket
|
||||
import threading
|
||||
from typing import List, Optional, Tuple
|
||||
from contextlib import suppress
|
||||
|
||||
import torch
|
||||
from vllm.logger import logger
|
||||
@@ -32,8 +32,7 @@ class ElasticClient:
|
||||
Class for handling the client-side logic of Netloader of models.
|
||||
"""
|
||||
|
||||
def __init__(self, sources: list[str], device_id: int, model_path: str,
|
||||
tp: int, pp: int):
|
||||
def __init__(self, sources: list[str], device_id: int, model_path: str, tp: int, pp: int):
|
||||
"""
|
||||
Initializes the ElasticClient instance.
|
||||
|
||||
@@ -50,14 +49,14 @@ class ElasticClient:
|
||||
self.tp = tp
|
||||
self.pp = pp
|
||||
|
||||
self.s: Optional[socket.socket] = None
|
||||
self.ack: Optional[Tuple[str, int]] = None
|
||||
self.server_addr: Optional[str] = None
|
||||
self.server_port: Optional[int] = None
|
||||
self.s: socket.socket | None = None
|
||||
self.ack: tuple[str, int] | None = None
|
||||
self.server_addr: str | None = None
|
||||
self.server_port: int | None = None
|
||||
|
||||
for source in self.sources:
|
||||
try:
|
||||
ip, port_str = source.split(':')
|
||||
ip, port_str = source.split(":")
|
||||
port = int(port_str)
|
||||
except Exception as e:
|
||||
logger.info(f"IP format error: {source}, detail: {e}")
|
||||
@@ -68,13 +67,9 @@ class ElasticClient:
|
||||
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
logger.info(
|
||||
f"Start connection to server: {self.server_addr}:{self.server_port}"
|
||||
)
|
||||
logger.info(f"Start connection to server: {self.server_addr}:{self.server_port}")
|
||||
sock.connect((self.server_addr, self.server_port))
|
||||
logger.info(
|
||||
f"Finish connection to server: {self.server_addr}:{self.server_port}"
|
||||
)
|
||||
logger.info(f"Finish connection to server: {self.server_addr}:{self.server_port}")
|
||||
sock.settimeout(60)
|
||||
|
||||
self.s = sock
|
||||
@@ -83,10 +78,8 @@ class ElasticClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Connect to {source} fails, detail: {e}")
|
||||
if sock is not None:
|
||||
try:
|
||||
with suppress(Exception):
|
||||
sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.s = None
|
||||
self.ack = None
|
||||
self.server_addr = None
|
||||
@@ -120,10 +113,8 @@ class ElasticClient:
|
||||
"""
|
||||
Destructor method to ensure socket is closed.
|
||||
"""
|
||||
try:
|
||||
with suppress(Exception):
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def send_str(self, data_str: str) -> None:
|
||||
"""
|
||||
@@ -151,8 +142,7 @@ class ElasticClient:
|
||||
data_str = self.s.recv(buffer_size).decode("utf-8")
|
||||
return data_str
|
||||
|
||||
def register(self, device_id: int, model_path: str, tp: int,
|
||||
pp: int) -> Tuple[str, int]:
|
||||
def register(self, device_id: int, model_path: str, tp: int, pp: int) -> tuple[str, int]:
|
||||
"""
|
||||
Registers the client with the server.
|
||||
|
||||
@@ -168,20 +158,13 @@ class ElasticClient:
|
||||
free_port = find_free_port()
|
||||
data = {
|
||||
"label": "JOIN",
|
||||
"content": {
|
||||
'device_id': device_id,
|
||||
'model_path': model_path,
|
||||
'tp': tp,
|
||||
'pp': pp,
|
||||
'port': free_port
|
||||
}
|
||||
"content": {"device_id": device_id, "model_path": model_path, "tp": tp, "pp": pp, "port": free_port},
|
||||
}
|
||||
|
||||
try:
|
||||
self.send_str(json.dumps(data))
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Send data {data} to server fails, detail: {e}")
|
||||
raise RuntimeError(f"Send data {data} to server fails, detail: {e}")
|
||||
|
||||
try:
|
||||
ack_str = self.recv_str()
|
||||
@@ -191,23 +174,22 @@ class ElasticClient:
|
||||
try:
|
||||
ack = json.loads(ack_str)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}"
|
||||
)
|
||||
raise RuntimeError(f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}")
|
||||
|
||||
logger.info(f"Receive ack: {ack}")
|
||||
|
||||
if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack
|
||||
and ack["content"] is not None and "name" in ack["content"]):
|
||||
if (
|
||||
"label" in ack
|
||||
and ack["label"] == "JOIN_ACK"
|
||||
and "content" in ack
|
||||
and ack["content"] is not None
|
||||
and "name" in ack["content"]
|
||||
):
|
||||
return (ack["content"]["name"], free_port)
|
||||
elif ("label" in ack and ack["label"] == 'JOIN_NACK'
|
||||
and "content" in ack):
|
||||
raise RuntimeError(
|
||||
f"Receive nack from server, reason: {ack['content']}")
|
||||
elif "label" in ack and ack["label"] == "JOIN_NACK" and "content" in ack:
|
||||
raise RuntimeError(f"Receive nack from server, reason: {ack['content']}")
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Receive ack {ack} from server does not contain required fields"
|
||||
)
|
||||
raise RuntimeError(f"Receive ack {ack} from server does not contain required fields")
|
||||
|
||||
|
||||
class ElasticServer:
|
||||
@@ -215,9 +197,18 @@ class ElasticServer:
|
||||
Class for handling the server-side logic of Netloader of models.
|
||||
"""
|
||||
|
||||
def __init__(self, addr: str, port: int, model, device_id: int,
|
||||
model_path: str, tp: int, pp: int, int8_cache: str,
|
||||
int8_cache_name: Optional[List[str]]):
|
||||
def __init__(
|
||||
self,
|
||||
addr: str,
|
||||
port: int,
|
||||
model,
|
||||
device_id: int,
|
||||
model_path: str,
|
||||
tp: int,
|
||||
pp: int,
|
||||
int8_cache: str,
|
||||
int8_cache_name: list[str] | None,
|
||||
):
|
||||
"""
|
||||
Initializes the ElasticServer instance.
|
||||
|
||||
@@ -246,30 +237,25 @@ class ElasticServer:
|
||||
self.pp = pp
|
||||
|
||||
self.original_int8 = {}
|
||||
int8_pattern = "|".join(
|
||||
map(re.escape,
|
||||
int8_cache_name)) if int8_cache_name is not None else "(?:)"
|
||||
int8_pattern = "|".join(map(re.escape, int8_cache_name)) if int8_cache_name is not None else "(?:)"
|
||||
for name, param in self.model.named_parameters():
|
||||
if param.dtype == torch.int8:
|
||||
if int8_cache == 'hbm':
|
||||
if int8_cache == "hbm":
|
||||
if int8_cache_name is None or (
|
||||
int8_cache_name is not None
|
||||
and re.search(int8_pattern, name) is not None):
|
||||
int8_cache_name is not None and re.search(int8_pattern, name) is not None
|
||||
):
|
||||
try:
|
||||
self.original_int8[name] = param.data.clone(
|
||||
).detach()
|
||||
self.original_int8[name] = param.data.clone().detach()
|
||||
except RuntimeError as e:
|
||||
logger.error(
|
||||
f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}"
|
||||
)
|
||||
logger.error(f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}")
|
||||
self.original_int8[name] = param.data.cpu()
|
||||
|
||||
elif int8_cache == 'dram':
|
||||
elif int8_cache == "dram":
|
||||
if int8_cache_name is None or (
|
||||
int8_cache_name is not None
|
||||
and re.search(int8_pattern, name) is not None):
|
||||
int8_cache_name is not None and re.search(int8_pattern, name) is not None
|
||||
):
|
||||
self.original_int8[name] = param.data.cpu()
|
||||
elif int8_cache == 'no':
|
||||
elif int8_cache == "no":
|
||||
pass
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -277,14 +263,18 @@ class ElasticServer:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}"
|
||||
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, "
|
||||
f"model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, "
|
||||
f"int8 params {list(self.original_int8)} are saved to {int8_cache}"
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor method to ensure socket is closed.
|
||||
"""
|
||||
self.s.close()
|
||||
if self.s is not None:
|
||||
with suppress(Exception):
|
||||
self.s.close()
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
@@ -343,10 +333,7 @@ class ElasticServer:
|
||||
if not all(k in content for k in required_keys):
|
||||
return False
|
||||
port = content["port"]
|
||||
if not (isinstance(port, int) or
|
||||
(isinstance(port, str) and port.isdigit())):
|
||||
return False
|
||||
return True
|
||||
return isinstance(port, int) or (isinstance(port, str) and port.isdigit())
|
||||
|
||||
comm_name = None
|
||||
if is_valid_data(data):
|
||||
@@ -355,36 +342,31 @@ class ElasticServer:
|
||||
tp = int(data["content"]["tp"])
|
||||
pp = int(data["content"]["pp"])
|
||||
|
||||
if int(self.device_id
|
||||
) == device_id and self.model_path == model_path and int(
|
||||
self.tp) == tp and int(self.pp) == pp:
|
||||
if (
|
||||
int(self.device_id) == device_id
|
||||
and self.model_path == model_path
|
||||
and int(self.tp) == tp
|
||||
and int(self.pp) == pp
|
||||
):
|
||||
comm_name = str(addr[0]) + ":" + str(addr[1])
|
||||
ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
|
||||
)
|
||||
server_desc = (int(self.device_id), self.model_path, int(self.tp), int(self.pp))
|
||||
client_desc = (device_id, model_path, tp, pp)
|
||||
msg = f"Received data {client_desc} does not consist with this server {server_desc}"
|
||||
logger.warning(msg)
|
||||
ack = {
|
||||
"label":
|
||||
"JOIN_NACK",
|
||||
"content":
|
||||
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
|
||||
"label": "JOIN_NACK",
|
||||
"content": msg,
|
||||
}
|
||||
else:
|
||||
logger.warning(
|
||||
f"Received data does not contain required fields: {data}")
|
||||
ack = {
|
||||
"label":
|
||||
"JOIN_NACK",
|
||||
"content":
|
||||
f"Received data does not contain required fields: {data}"
|
||||
}
|
||||
logger.warning(f"Received data does not contain required fields: {data}")
|
||||
ack = {"label": "JOIN_NACK", "content": f"Received data does not contain required fields: {data}"}
|
||||
|
||||
try:
|
||||
ack_str = json.dumps(ack).encode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to convert {ack} to JSON format, details: {e}")
|
||||
logger.error(f"Failed to convert {ack} to JSON format, details: {e}")
|
||||
conn.close()
|
||||
return
|
||||
|
||||
@@ -395,14 +377,10 @@ class ElasticServer:
|
||||
conn.close()
|
||||
return
|
||||
|
||||
if ack["content"] and isinstance(ack["content"],
|
||||
dict) and 'name' in ack["content"]:
|
||||
if ack["content"] and isinstance(ack["content"], dict) and "name" in ack["content"]:
|
||||
try:
|
||||
p2psend = P2PSend(self.addr, data["content"]["port"],
|
||||
ack["content"]["name"])
|
||||
p2psend = P2PSend(self.addr, data["content"]["port"], ack["content"]["name"])
|
||||
p2psend.send(self.model, self.original_int8)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"P2PSend Failed to send model to {self.addr}, details: {e}"
|
||||
)
|
||||
logger.error(f"P2PSend Failed to send model to {self.addr}, details: {e}")
|
||||
conn.close()
|
||||
|
||||
Reference in New Issue
Block a user