Files
xc-llm-ascend/vllm_ascend/model_loader/netloader/interaction/elastic.py
SILONG ZENG 1e3c1e76bf [Lint]Add lint hooks for clang-format, shellcheck, forbidden imports, and boolean context manager checks (#7511)
### What this PR does / why we need it?
This PR introduces several upstream `vllm`-aligned lint hooks into
`vllm-ascend` and makes them part of the actual `pre-commit` flow.

Main changes in this PR:
- add `check-boolean-context-manager` to catch boolean expressions in
`with` statements
- add `check-forbidden-imports` to forbid direct `re` imports and
disallowed direct `triton` imports
- enable shell script linting through `tools/shellcheck.sh`
- add a root `.clang-format` aligned with upstream `vllm`, enable
`clang-format` in `pre-commit`, and temporarily **exclude all of `csrc/**`**
from `clang-format` to avoid bringing a large native-code reformat into
this PR

This PR focuses on landing the smaller and immediately useful lint
alignment first, without mixing in the larger requirements-management
migration.

### Does this PR introduce _any_ user-facing change?
No.

This PR only updates repository lint configuration, static checks, and
internal import/style enforcement. It does not change runtime behavior
or public interfaces.

### How was this patch tested?
Tested locally in the project virtual environment.

Commands used:
```bash
bash format.sh
```
Verified checks passed:
```bash
ruff check...............................................................Passed
ruff format..............................................................Passed
codespell................................................................Passed
typos....................................................................Passed
clang-format.............................................................Passed
Lint GitHub Actions workflow files.......................................Passed
Lint shell scripts.......................................................Passed
Lint PNG exports from excalidraw.........................................Passed
Check for spaces in all filenames........................................Passed
Enforce __init__.py in Python packages...................................Passed
Check for forbidden imports..............................................Passed
Check for boolean ops in with-statements.................................Passed
Suggestion...............................................................Passed
- hook id: suggestion
- duration: 0s

To bypass pre-commit hooks, add --no-verify to git commit.
```
**Note:**
`clang-format` is enabled but currently excludes everything under `csrc/**`.


- vLLM version: v0.17.0
- vLLM main:
8b6325758c

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
2026-03-24 20:03:01 +08:00

387 lines
13 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import socket
import threading
from contextlib import suppress
import regex as re
import torch
from vllm.logger import logger
from ..executor.elastic_load import P2PSend
from ..utils import find_free_port
class ElasticClient:
    """
    Class for handling the client-side logic of Netloader of models.

    The constructor tries each entry of ``sources`` in order and keeps the
    first connection on which the server answers the JOIN request with a
    JOIN_ACK. When no source succeeds, ``s``, ``ack``, ``server_addr`` and
    ``server_port`` all stay ``None``. The instance can be used as a context
    manager; leaving the ``with`` block closes the socket.
    """

    def __init__(self, sources: list[str], device_id: int, model_path: str, tp: int, pp: int):
        """
        Initializes the ElasticClient instance.

        Parameters:
        - sources: List of source addresses in the format IP:port.
        - device_id: The ID of the current device.
        - model_path: The path to the model.
        - tp: Tensor parallel size.
        - pp: Pipeline parallel size.
        """
        self.sources = sources
        self.device_id = device_id
        self.model_path = model_path
        self.tp = tp
        self.pp = pp

        self.s: socket.socket | None = None
        self.ack: tuple[str, int] | None = None
        self.server_addr: str | None = None
        self.server_port: int | None = None

        for source in self.sources:
            try:
                ip, port_str = source.split(":")
                port = int(port_str)
            except Exception as e:
                # A malformed entry is skipped; the remaining sources are still tried.
                logger.info(f"IP format error: {source}, detail: {e}")
                continue

            self.server_addr = ip
            self.server_port = port

            # Bind the name before the try block so the cleanup below never
            # touches an unbound name or a socket left over from a previous
            # iteration when socket.socket() itself fails.
            sock: socket.socket | None = None
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                logger.info(f"Start connection to server: {self.server_addr}:{self.server_port}")
                sock.connect((self.server_addr, self.server_port))
                logger.info(f"Finish connection to server: {self.server_addr}:{self.server_port}")
                # The timeout is applied after connect, so it bounds the
                # registration exchange rather than the connect call.
                sock.settimeout(60)
                self.s = sock
                self.ack = self.register(device_id, model_path, tp, pp)
                break
            except Exception as e:
                logger.error(f"Connect to {source} fails, detail: {e}")
                if sock is not None:
                    with suppress(Exception):
                        sock.close()
                # Reset all connection state before trying the next source.
                self.s = None
                self.ack = None
                self.server_addr = None
                self.server_port = None

    def close(self) -> None:
        """
        Closes the socket connection.

        Safe to call repeatedly; errors raised while closing are logged and
        ``self.s`` is reset to ``None`` in every case.
        """
        if self.s is not None:
            try:
                self.s.close()
            except Exception as e:
                logger.error(f"Error closing socket: {e}")
            finally:
                self.s = None

    def __enter__(self) -> "ElasticClient":
        """
        Context manager enter method.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Context manager exit method; closes the socket.
        """
        self.close()

    def __del__(self):
        """
        Destructor method to ensure socket is closed.
        """
        # suppress() also covers a partially constructed instance on which
        # ``self.s`` was never assigned.
        with suppress(Exception):
            self.close()

    def send_str(self, data_str: str) -> None:
        """
        Sends a string over the socket connection.

        Parameters:
        - data_str: The string to be sent (encoded as UTF-8).

        Raises:
        - RuntimeError: If no socket is currently open.
        """
        if self.s is None:
            raise RuntimeError("Socket was not created correctly.")
        # sendall() keeps writing until the whole payload is out, whereas
        # send() may transmit only a prefix of it.
        self.s.sendall(data_str.encode("utf-8"))

    def recv_str(self, buffer_size: int = 1024) -> str:
        """
        Receives a string over the socket connection.

        Parameters:
        - buffer_size: The size of the buffer for receiving data.

        Returns:
        - The received string (decoded as UTF-8).

        Raises:
        - RuntimeError: If no socket is currently open.
        """
        if self.s is None:
            raise RuntimeError("Socket was not created correctly.")
        # NOTE: a single recv() call is made, so at most ``buffer_size`` bytes
        # of the peer's message are returned.
        return self.s.recv(buffer_size).decode("utf-8")

    def register(self, device_id: int, model_path: str, tp: int, pp: int) -> tuple[str, int]:
        """
        Registers the client with the server.

        Sends a JOIN message carrying the device/model description together
        with a locally free port, then waits for the server's answer.

        Parameters:
        - device_id: The ID of the current device.
        - model_path: The path to the model.
        - tp: Tensor parallel size.
        - pp: Pipeline parallel size.

        Returns:
        - A tuple containing the communication name and port.

        Raises:
        - RuntimeError: If sending or receiving fails, the answer is not valid
          JSON, the server answers JOIN_NACK, or the answer lacks the
          required fields.
        """
        free_port = find_free_port()
        data = {
            "label": "JOIN",
            "content": {"device_id": device_id, "model_path": model_path, "tp": tp, "pp": pp, "port": free_port},
        }

        try:
            self.send_str(json.dumps(data))
        except Exception as e:
            raise RuntimeError(f"Send data {data} to server fails, detail: {e}") from e

        try:
            ack_str = self.recv_str()
        except Exception as e:
            raise RuntimeError(f"Receive data from server fails, detail: {e}") from e

        try:
            ack = json.loads(ack_str)
        except Exception as e:
            raise RuntimeError(f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}") from e

        logger.info(f"Receive ack: {ack}")

        # Valid JSON is not necessarily an object; anything else cannot carry
        # the required fields.
        if not isinstance(ack, dict):
            raise RuntimeError(f"Receive ack {ack} from server does not contain required fields")

        label = ack.get("label")
        content = ack.get("content")
        # Require a dict so that a string or list payload that merely contains
        # "name" is reported as malformed instead of failing on indexing.
        if label == "JOIN_ACK" and isinstance(content, dict) and "name" in content:
            return (content["name"], free_port)
        if label == "JOIN_NACK" and "content" in ack:
            raise RuntimeError(f"Receive nack from server, reason: {ack['content']}")
        raise RuntimeError(f"Receive ack {ack} from server does not contain required fields")
class ElasticServer:
"""
Class for handling the server-side logic of Netloader of models.
"""
def __init__(
self,
addr: str,
port: int,
model,
device_id: int,
model_path: str,
tp: int,
pp: int,
int8_cache: str,
int8_cache_name: list[str] | None,
):
"""
Initializes the ElasticServer instance.
Parameters:
- addr: The IP address to listen on.
- port: The port number to listen on.
- model: The model to be served.
- device_id: The ID of the current device (i.e. global rank).
- model_path: The path to the model.
- tp: Tensor parallel size.
- pp: Pipeline parallel size.
- int8_cache: The type of caching for int8 parameters (HBM, DRAM, or no).
- int8_cache_name: List of parameter names to be cached.
"""
self.addr = addr
self.port = port
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.s.bind((self.addr, self.port))
self.s.listen(256)
self.model = model
self.device_id = device_id
self.model_path = model_path
self.tp = tp
self.pp = pp
self.original_int8 = {}
int8_pattern = "|".join(map(re.escape, int8_cache_name)) if int8_cache_name is not None else "(?:)"
for name, param in self.model.named_parameters():
if param.dtype == torch.int8:
if int8_cache == "hbm":
if int8_cache_name is None or (
int8_cache_name is not None and re.search(int8_pattern, name) is not None
):
try:
self.original_int8[name] = param.data.clone().detach()
except RuntimeError as e:
logger.error(f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}")
self.original_int8[name] = param.data.cpu()
elif int8_cache == "dram":
if int8_cache_name is None or (
int8_cache_name is not None and re.search(int8_pattern, name) is not None
):
self.original_int8[name] = param.data.cpu()
elif int8_cache == "no":
pass
else:
logger.warning(
f"int8_cache should be selected in [HBM, DRAM], but got {int8_cache}, change to no cache"
)
logger.info(
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, "
f"model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, "
f"int8 params {list(self.original_int8)} are saved to {int8_cache}"
)
def __del__(self):
"""
Destructor method to ensure socket is closed.
"""
if self.s is not None:
with suppress(Exception):
self.s.close()
def start(self):
"""
Starts the server to handle incoming connections.
"""
handler_thread = threading.Thread(target=self.elastic_client_handler)
handler_thread.daemon = True
handler_thread.start()
def elastic_client_handler(self):
"""
Handles incoming client connections.
"""
while True:
conn, addr = self.s.accept()
logger.info("Accept new connection from {}:{}...".format(*addr))
self.register_handler(conn, addr)
def register_handler(self, conn, addr, buffer_size=1024):
"""
Handles the registration of a client.
Parameters:
- conn: The connection socket.
- addr: The address of the client.
- buffer_size: The size of the buffer for receiving data.
"""
data_str = conn.recv(buffer_size).decode("utf-8")
if not data_str:
return
try:
data = json.loads(data_str)
except Exception:
logger.error(f"Failed to load {data_str} as JSON string")
conn.close()
return
def is_valid_data(data):
"""
Validates the received data.
Parameters:
- data: The data to be validated.
Returns:
- True if the data is valid, otherwise False.
"""
if not isinstance(data, dict):
return False
if data.get("label") != "JOIN":
return False
content = data.get("content")
if not isinstance(content, dict):
return False
required_keys = ["device_id", "model_path", "tp", "pp", "port"]
if not all(k in content for k in required_keys):
return False
port = content["port"]
return isinstance(port, int) or (isinstance(port, str) and port.isdigit())
comm_name = None
if is_valid_data(data):
device_id = int(data["content"]["device_id"])
model_path = data["content"]["model_path"]
tp = int(data["content"]["tp"])
pp = int(data["content"]["pp"])
if (
int(self.device_id) == device_id
and self.model_path == model_path
and int(self.tp) == tp
and int(self.pp) == pp
):
comm_name = str(addr[0]) + ":" + str(addr[1])
ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
else:
server_desc = (int(self.device_id), self.model_path, int(self.tp), int(self.pp))
client_desc = (device_id, model_path, tp, pp)
msg = f"Received data {client_desc} does not consist with this server {server_desc}"
logger.warning(msg)
ack = {
"label": "JOIN_NACK",
"content": msg,
}
else:
logger.warning(f"Received data does not contain required fields: {data}")
ack = {"label": "JOIN_NACK", "content": f"Received data does not contain required fields: {data}"}
try:
ack_str = json.dumps(ack).encode("utf-8")
except Exception as e:
logger.error(f"Failed to convert {ack} to JSON format, details: {e}")
conn.close()
return
try:
conn.send(ack_str)
except Exception as e:
logger.error(f"Failed to send {ack} to {addr}, details: {e}")
conn.close()
return
if ack["content"] and isinstance(ack["content"], dict) and "name" in ack["content"]:
try:
p2psend = P2PSend(self.addr, data["content"]["port"], ack["content"]["name"])
p2psend.send(self.model, self.original_int8)
except Exception as e:
logger.error(f"P2PSend Failed to send model to {self.addr}, details: {e}")
conn.close()