[Lint]Style: Convert example to ruff format (#5863)

### What this PR does / why we need it?
This PR fixes linting issues in the `examples/` directory to align with the
project's Ruff configuration.

- vLLM version: v0.13.0
- vLLM main: bde38c11df
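
For context, a minimal before/after sketch (illustrative only, not taken from the diff) of the kind of changes applied here: double quotes, built-in generics in place of `typing.List`, and collapsed argument wrapping.

```python
# Illustrative only: mirrors the style changes visible throughout this diff.
import argparse

host, port = "localhost", 8001

# Before (yapf-era style):
#     url = f'http://{host}:{port}/v1'
#     servers: List[str] = [url]
#     parser.add_argument("--max-retries",
#                         type=int,
#                         default=3,
#                         help="Maximum number of retries for HTTP requests")

# After ruff format and pyupgrade-style rules:
url = f"http://{host}:{port}/v1"
servers: list[str] = [url]

parser = argparse.ArgumentParser()
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
```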

Signed-off-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Co-authored-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Author: SILONG ZENG
Date: 2026-01-13 20:46:50 +08:00
Committed by: GitHub
Parent: f7b904641e
Commit: 78d5ce3e01
23 changed files with 678 additions and 1037 deletions

View File

@@ -4,7 +4,6 @@ default_install_hook_types:
default_stages:
- pre-commit # Run locally
- manual # Run in CI
exclude: 'examples/.*' # Exclude examples from all hooks by default
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.0

View File

@@ -91,10 +91,8 @@ import heapq
import ipaddress
import os
import sys
import threading
import uuid
from contextlib import asynccontextmanager
from typing import List
import httpx
from fastapi import FastAPI, Request
@@ -106,28 +104,28 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.url = f"http://{host}:{port}/v1"
try:
ip = ipaddress.ip_address(self.host)
if isinstance(ip, ipaddress.IPv6Address):
self.url = f'http://[{host}]:{port}/v1'
self.url = f"http://[{host}]:{port}/v1"
except Exception:
pass
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.client = httpx.AsyncClient(
timeout=None,
base_url=self.url,
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
)
self.active_tokens = 0
self.active_kv_cache = 0 # Only for prefiller
self.active_requests = 0 # Number of active requests
@@ -136,14 +134,9 @@ class ServerState:
class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.prefillers: List[ServerState] = [
ServerState(h, p) for h, p in prefiller_instances
]
self.decoders: List[ServerState] = [
ServerState(h, p) for h, p in decoder_instances
]
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
self.req_to_prefiller = {}
self.req_id_lock = asyncio.Lock()
# Removed selection locks - no longer needed for synchronous methods
@@ -151,10 +144,8 @@ class ProxyState:
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.prefiller_heap = [(0, i, server)
for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server)
for i, server in enumerate(self.decoders)]
self.prefiller_heap = [(0, i, server) for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server) for i, server in enumerate(self.decoders)]
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
self.req_id_future = {}
@@ -166,23 +157,18 @@ class ProxyState:
# Priority based on active_tokens and active_kv_cache
priority = server.active_tokens + server.active_kv_cache * 0.3
# Remove old entry and add new one
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
if i != server_idx]
heapq.heappush(self.prefiller_heap,
(priority, server_idx, server)) # type: ignore
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
heapq.heappush(self.prefiller_heap, (priority, server_idx, server)) # type: ignore
def _update_decoder_priority(self, server_idx: int):
"""Update the priority of a decoder server in the heap."""
server = self.decoders[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
if i != server_idx]
heapq.heappush(self.decoder_heap,
(priority, server_idx, server)) # type: ignore
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
heapq.heappush(self.decoder_heap, (priority, server_idx, server)) # type: ignore
def abort_prefiller_request(self, server_idx: int,
request_id): # Changed to synchronous
def abort_prefiller_request(self, server_idx: int, request_id): # Changed to synchronous
"""
Mark a request as aborted. This will helps to release kv cache in
prefiller node.
@@ -190,8 +176,7 @@ class ProxyState:
# No lock needed - atomic operation
self.prefillers[server_idx].aborted_requests.add(request_id)
def aquire_aborted_prefiller_requests(
self, server_idx: int): # Changed to synchronous
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
"""
Get the set of aborted requests and clear it.
This is used to release kv cache in prefiller node.
@@ -272,37 +257,20 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--decoder-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
)
args = parser.parse_args()
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
raise ValueError("Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
raise ValueError("Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
@@ -310,11 +278,8 @@ def parse_args():
@asynccontextmanager
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.prefiller_instances,
global_args.decoder_instances)
print(
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
)
proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
yield
for p in proxy_state.prefillers:
await p.client.aclose()
@@ -331,14 +296,12 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
@@ -351,15 +314,16 @@ def with_cancellation(handler_func):
app = FastAPI(lifespan=lifespan)
async def send_request_to_service(client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id)
async def send_request_to_service(
client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
req_data = req_data.copy()
req_data["stream"] = False
req_data["max_tokens"] = 1
@@ -368,49 +332,38 @@ async def send_request_to_service(client: httpx.AsyncClient,
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
last_exc = None
for attempt in range(1, max_retries + 1):
try:
response = await client.post(endpoint,
json=req_data,
headers=headers)
response = await client.post(endpoint, json=req_data, headers=headers)
response.raise_for_status()
if request_id in proxy_state.req_id_future:
result_future = proxy_state.req_id_future[request_id]
result_future.set_result(response.json()["kv_transfer_params"])
return
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
last_exc = e
if attempt < max_retries:
await asyncio.sleep(base_delay * (2**(attempt - 1)))
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for {endpoint}.")
logger.error(f"All {max_retries} attempts failed for {endpoint}.")
raise last_exc
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async def stream_service_response_with_retry(
client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
@@ -419,32 +372,22 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
if "first_chunk_sent" in locals() and first_chunk_sent:
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
@@ -469,15 +412,11 @@ async def _handle_completions(api: str, request: Request):
request_length = len(req_body)
request_id = await proxy_state.next_req_id()
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
api)
req_data['kv_transfer_params'] = {
"do_remote_decode":
False,
"do_remote_prefill":
True,
"metaserver":
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
req_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver",
}
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
@@ -494,28 +433,30 @@ async def _handle_completions(api: str, request: Request):
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
f"Error during streaming from decoder {decoder.url}: {str(e)} "
f"the aborted request {request_id} will be routing to the target "
"prefiller when new request is ready to dispatch to it"
)
# After streaming done, release tokens
proxy_state.release_decoder(decoder_idx, decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -538,7 +479,7 @@ async def healthcheck():
return {
"status": "ok",
"prefill_instances": len(proxy_state.prefillers),
"decode_instances": len(proxy_state.decoders)
"decode_instances": len(proxy_state.decoders),
}
@@ -553,25 +494,24 @@ async def metaserver(request: Request):
request_id = get_origin_request_id(api, request_id)
req_data["kv_transfer_params"] = kv_transfer_params
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
# Send request to prefiller
response = await send_request_to_service(
await send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
base_delay=global_args.retry_delay,
)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
proxy_state.release_prefiller_kv(prefiller_idx,prefiller_score)
proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
except Exception as e:
logger.error(f"Post metaserver failed with: {str(e)}")
@@ -579,8 +519,9 @@ async def metaserver(request: Request):
proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)

View File

@@ -125,7 +125,7 @@ import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List, Tuple, Dict
from typing import Any
import httpx
from fastapi import FastAPI, Request
@@ -150,22 +150,21 @@ class InstanceType:
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.url = f"http://{host}:{port}/v1"
try:
ip = ipaddress.ip_address(self.host)
if isinstance(ip, ipaddress.IPv6Address):
self.url = f'http://[{host}]:{port}/v1'
self.url = f"http://[{host}]:{port}/v1"
except Exception:
pass
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.client = httpx.AsyncClient(
timeout=None,
base_url=self.url,
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
)
self.active_tokens = 0
self.active_kv_cache = 0 # Only for prefiller
self.active_requests = 0 # Number of active requests
@@ -186,16 +185,11 @@ class ServerState:
class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.node_listener = NodeListener(self)
self.prefillers: List[ServerState] = [
ServerState(h, p) for h, p in prefiller_instances
]
self.decoders: List[ServerState] = [
ServerState(h, p) for h, p in decoder_instances
]
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
self.req_to_prefiller = {}
self.req_id_lock = asyncio.Lock()
# Removed selection locks - no longer needed for synchronous methods
@@ -203,10 +197,8 @@ class ProxyState:
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.prefiller_heap = [(0, i, server)
for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server)
for i, server in enumerate(self.decoders)]
self.prefiller_heap = [(0, i, server) for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server) for i, server in enumerate(self.decoders)]
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
@@ -216,23 +208,18 @@ class ProxyState:
# Priority based on active_tokens and active_kv_cache
priority = server.active_tokens + server.active_kv_cache * 0.3
# Remove old entry and add new one
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
if i != server_idx]
heapq.heappush(self.prefiller_heap,
(priority, server_idx, server)) # type: ignore
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
heapq.heappush(self.prefiller_heap, (priority, server_idx, server)) # type: ignore
def _update_decoder_priority(self, server_idx: int):
"""Update the priority of a decoder server in the heap."""
server = self.decoders[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
if i != server_idx]
heapq.heappush(self.decoder_heap,
(priority, server_idx, server)) # type: ignore
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
heapq.heappush(self.decoder_heap, (priority, server_idx, server)) # type: ignore
def abort_prefiller_request(self, server_idx: int,
request_id): # Changed to synchronous
def abort_prefiller_request(self, server_idx: int, request_id): # Changed to synchronous
"""
Mark a request as aborted. This will helps to release kv cache in
prefiller node.
@@ -240,8 +227,7 @@ class ProxyState:
# No lock needed - atomic operation
self.prefillers[server_idx].aborted_requests.add(request_id)
def aquire_aborted_prefiller_requests(
self, server_idx: int): # Changed to synchronous
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
"""
Get the set of aborted requests and clear it.
This is used to release kv cache in prefiller node.
@@ -314,9 +300,7 @@ class ProxyState:
def calculate_decode_scores(self, request_length: int) -> float:
return request_length
async def add_instances(
self, instance_type: str, instances: List[ServerState]
) -> Tuple[List[str], List[str]]:
async def add_instances(self, instance_type: str, instances: list[ServerState]) -> tuple[list[str], list[str]]:
added_nodes, waiting_nodes = [], []
for server in instances:
is_valid = await self.node_listener.check_instance_status(server.client)
@@ -332,7 +316,7 @@ class ProxyState:
waiting_nodes.append(node)
return added_nodes, waiting_nodes
def add_prefillers(self, instances: List[ServerState]) -> None:
def add_prefillers(self, instances: list[ServerState]) -> None:
num_prefillers = len(self.prefillers)
for idx, server in enumerate(instances):
if server not in self.prefillers:
@@ -341,7 +325,7 @@ class ProxyState:
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
self.print_status(f"Add prefiller instances: {instances}.")
def add_decoders(self, instances: List[ServerState]) -> None:
def add_decoders(self, instances: list[ServerState]) -> None:
num_decoders = len(self.decoders)
for idx, server in enumerate(instances):
if server not in self.decoders:
@@ -350,7 +334,7 @@ class ProxyState:
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
self.print_status(f"Add decoder instances: {instances}.")
def remove_prefillers(self, instances: List[ServerState]) -> None:
def remove_prefillers(self, instances: list[ServerState]) -> None:
instances_to_remove = set(instances)
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
prefiller_heap_copy = self.prefiller_heap.copy()
@@ -367,7 +351,7 @@ class ProxyState:
heapq.heapify(self.prefiller_heap)
self.print_status(f"Remove prefiller instances: {instances}.")
def remove_decoders(self, instances: List[ServerState]) -> None:
def remove_decoders(self, instances: list[ServerState]) -> None:
instances_to_remove = set(instances)
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
decoder_heap_copy = self.decoder_heap.copy()
@@ -387,7 +371,7 @@ class ProxyState:
def print_status(self, msg: str) -> None:
status = {
"prefill_instances": [str(server) for server in self.prefillers],
"decode_instances": [str(server) for server in self.decoders]
"decode_instances": [str(server) for server in self.decoders],
}
print(f"{msg} Status: {status}")
@@ -398,7 +382,7 @@ proxy_state = None
class NodeListener:
def __init__(self, proxy):
self.proxy_state = proxy
self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
self.waiting_nodes: dict[str, tuple[str, Any, int]] = {}
self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
self.listening_thread.start()
@@ -424,9 +408,7 @@ class NodeListener:
@staticmethod
async def check_instance_status(client: httpx.AsyncClient) -> bool:
endpoint = "/models"
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
try:
response = await client.get(endpoint, headers=headers)
response.raise_for_status()
@@ -439,46 +421,29 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--decoder-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
parser.add_argument("--max-waiting-retries",
type=int,
default=3,
help="Maximum number of retries for waiting nodes to be started")
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
)
parser.add_argument(
"--max-waiting-retries", type=int, default=3, help="Maximum number of retries for waiting nodes to be started"
)
parser.add_argument(
"--waiting-retry-interval",
type=float,
default=10,
help="Check interval (seconds) for waiting nodes to be started")
help="Check interval (seconds) for waiting nodes to be started",
)
args = parser.parse_args()
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
raise ValueError("Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
raise ValueError("Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
@@ -486,11 +451,8 @@ def parse_args():
@asynccontextmanager
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.prefiller_instances,
global_args.decoder_instances)
print(
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
)
proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
yield
for p in proxy_state.prefillers:
await p.client.aclose()
@@ -507,14 +469,12 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
@@ -527,17 +487,18 @@ def with_cancellation(handler_func):
app = FastAPI(lifespan=lifespan)
async def send_request_to_service(client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id)
async def send_request_to_service(
client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
req_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
@@ -553,46 +514,35 @@ async def send_request_to_service(client: httpx.AsyncClient,
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
last_exc = None
for attempt in range(1, max_retries + 1):
try:
response = await client.post(endpoint,
json=req_data,
headers=headers)
response = await client.post(endpoint, json=req_data, headers=headers)
response.raise_for_status()
return response
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
last_exc = e
if attempt < max_retries:
await asyncio.sleep(base_delay * (2**(attempt - 1)))
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for {endpoint}.")
logger.error(f"All {max_retries} attempts failed for {endpoint}.")
raise last_exc
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async def stream_service_response_with_retry(
client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
@@ -601,41 +551,28 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
if "first_chunk_sent" in locals() and first_chunk_sent:
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
async def _handle_select_instance(api: str, req_data: Any,
request_length: int):
async def _handle_select_instance(api: str, req_data: Any, request_length: int):
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
@@ -648,10 +585,11 @@ async def _handle_select_instance(api: str, req_data: Any,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
base_delay=global_args.retry_delay,
)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
kv_transfer_params = response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Select decoder
@@ -661,13 +599,15 @@ async def _handle_select_instance(api: str, req_data: Any,
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
return InstanceInfo(request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score)
return InstanceInfo(
request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score,
)
@dataclass
@@ -686,8 +626,7 @@ async def _handle_completions(api: str, request: Request):
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
instance_info = await _handle_select_instance(api, req_data,
request_length)
instance_info = await _handle_select_instance(api, req_data, request_length)
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data
@@ -713,34 +652,31 @@ async def _handle_completions(api: str, request: Request):
while retry:
retry = False
async for chunk in stream_service_response_with_retry(
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
instance_info.prefiller_idx,
instance_info.prefiller_score)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
released_kv = True
try:
chunk_str = chunk.decode("utf-8").strip()
except UnicodeDecodeError:
logger.debug(
f"Skipping chunk: {chunk}")
logger.debug(f"Skipping chunk: {chunk}")
yield chunk
continue
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: "):]
chunk_str = chunk_str[len("data: ") :]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.debug(
f"Skipping chunk: {chunk_str}")
logger.debug(f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
@@ -751,63 +687,52 @@ async def _handle_completions(api: str, request: Request):
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = (
delta.get("content")
or message.get("content")
or choice.get("text")
or ""
)
content = delta.get("content") or message.get("content") or choice.get("text") or ""
generated_token += content
stop_reason = choice.get(
"stop_reason")
stop_reason = choice.get("stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (completion_tokens + 1) if stream_flag else \
(completion_tokens + usage.get("completion_tokens"))
completion_tokens = (
(completion_tokens + 1)
if stream_flag
else (completion_tokens + usage.get("completion_tokens"))
)
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0][
"content"] = origin_prompt + generated_token
messages[0]["content"] = origin_prompt + generated_token
else:
req_data[
"prompt"] = origin_prompt + generated_token
req_data[
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(
json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(
api, req_data, tmp_request_length)
req_data["prompt"] = origin_prompt + generated_token
req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(api, req_data, tmp_request_length)
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choice["message"][
"content"] = generated_token
choice["message"]["content"] = generated_token
else:
choice["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} "
f"the aborted request {instance_info.request_id} will be routing to the target "
"prefiller when new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(
instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
instance_info.prefiller_score)
proxy_state.abort_prefiller_request(instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
# After streaming done, release tokens
proxy_state.release_decoder(instance_info.decoder_idx,
instance_info.decoder_score)
proxy_state.release_decoder(instance_info.decoder_idx, instance_info.decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -821,20 +746,21 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
if isinstance(instances, str):
instances = [instances]
instances = trans_instances(instances)
all_msg = f"{adjust_mode} {instance_type} instances: " \
f"{[str(server) for server in instances]}."
all_msg = f"{adjust_mode} {instance_type} instances: {[str(server) for server in instances]}."
if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
return {"error": f"Instance type {instance_type} is not supported. "
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
return {
"error": f"Instance type {instance_type} is not supported. "
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."
}
if adjust_mode == "add":
added_nodes, waiting_nodes = await proxy_state.add_instances(
instance_type, instances
)
added_nodes, waiting_nodes = await proxy_state.add_instances(instance_type, instances)
if waiting_nodes:
all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
f"Instances {waiting_nodes} are waiting to be added."
all_msg = (
f"{adjust_mode} {instance_type} instances: {added_nodes}. "
f"Instances {waiting_nodes} are waiting to be added."
)
elif adjust_mode == "remove":
if instance_type == InstanceType.PREFILL:
proxy_state.remove_prefillers(instances)
@@ -843,14 +769,14 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
return {
"message": all_msg,
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders],
}
except Exception as e:
logger.error(f"Failed to {adjust_mode} instances: {e}")
raise e
def trans_instances(instances: List[str]) -> List[ServerState]:
def trans_instances(instances: list[str]) -> list[ServerState]:
server_list = []
for instance in instances:
h, p = instance.split(":")
@@ -875,7 +801,7 @@ async def healthcheck():
return {
"status": "ok",
"prefill_instances": len(proxy_state.prefillers),
"decode_instances": len(proxy_state.decoders)
"decode_instances": len(proxy_state.decoders),
}
@@ -889,7 +815,7 @@ async def handle_remove_instances(request: Request):
return await _handle_adjust_instances("remove", request)
if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn

View File

@@ -2,17 +2,17 @@
## Environmental Dependencies
* Software:
* Python >= 3.10, < 3.12
* CANN == 8.3.rc2
* PyTorch == 2.8.0, torch-npu == 2.8.0
* vLLM (same version as vllm-ascend)
* mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
* Software:
* Python >= 3.10, < 3.12
* CANN == 8.3.rc2
* PyTorch == 2.8.0, torch-npu == 2.8.0
* vLLM (same version as vllm-ascend)
* mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
The vllm version must be the same as the main branch of vllm-ascend, for example, 2025/07/30. The version is
* vllm: v0.10.1
* vllm-ascend: v0.10.1rc1
* vllm: v0.10.1
* vllm-ascend: v0.10.1rc1
## run
@@ -84,7 +84,6 @@ Set `GLOO_SOCKET_IFNAME`, `TP_SOCKET_IFNAME`, and `HCCL_SOCKET_IFNAME` to the co
`--gpu-memory-utilization`: Percentage of video memory occupied by the card<br>
`--kv-transfer-config`: follow kv_connector, kv_connector_module_path: mooncakeconnect, kv_buffer_device, and run on the NPU card. For kv_role, set kv_producer to the p node, kv_consumer to the d node, kv_parallel_size to 1, and kv_port to the port used by the node. For the p node, set engine_id and kv_rank to 0 and for the d node to 1. Configure the distributed parallel policy for the p and d nodes in the kv_connector_extra_config file based on --tensor-parallel-size and --data-parallel-size.<br>
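
For illustration only (this sketch is not part of the diff or the repository), one hedged way the prefill-node `--kv-transfer-config` described above might be assembled; the connector name, the port, and the shape of `kv_connector_extra_config` are assumptions.

```python
# Hypothetical prefill-node kv-transfer-config built from the fields described above;
# connector name, port, and extra-config layout are assumptions for illustration.
import json

prefill_kv_config = {
    "kv_connector": "MooncakeConnector",            # assumed connector class name
    "kv_connector_module_path": "mooncakeconnect",  # module path mentioned above
    "kv_buffer_device": "npu",
    "kv_role": "kv_producer",  # p node; use "kv_consumer" on the d node
    "kv_parallel_size": 1,
    "kv_port": 20001,          # assumed port for this node
    "engine_id": 0,            # 0 on the p node, 1 on the d node
    "kv_rank": 0,              # 0 on the p node, 1 on the d node
    "kv_connector_extra_config": {
        # mirror this node's --tensor-parallel-size / --data-parallel-size here
        "tp_size": 1,
        "dp_size": 1,
    },
}

# The resulting JSON string is what would be passed to the --kv-transfer-config flag.
print(json.dumps(prefill_kv_config))
```
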
### 2. Run `decode` Node
```
@@ -151,7 +150,6 @@ python load_balance_proxy_server_example.py --host localhost --prefiller-hosts h
`--decoder-hosts`: Set this parameter to the IP addresses of all d nodes. In the xpyd scenario, add the IP addresses to the end of this configuration item and leave a blank space between the IP addresses.<br>
`--decoder-ports`: Set this parameter to the port number of all d nodes, which is the configuration of the port number for the vllm to start the service in step 4. Set port to the end of the configuration, and leave a blank space between port and port. The sequence must be one-to-one mapping to the IP address of --decoder-hosts.<br>
### 4. Run Inference
Set the IP address in the inference file to the actual IP address. Set the model variable to the path of the model. Ensure that the path is the same as that in the shell script.
@@ -162,4 +160,4 @@ curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json"
"prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?",
"max_tokens": 256
}'
```
```

View File

@@ -4,13 +4,11 @@ Expert parallelism load balancer (EPLB) for vLLM.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
"""
from typing import Tuple
import torch
def balanced_packing(weight: torch.Tensor,
num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]:
def balanced_packing(weight: torch.Tensor, num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
are as balanced as possible.
@@ -18,8 +16,8 @@ def balanced_packing(weight: torch.Tensor,
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
@@ -28,26 +26,18 @@ def balanced_packing(weight: torch.Tensor,
groups_per_pack = num_groups // num_packs
if groups_per_pack == 1:
pack_index = torch.arange(weight.size(-1),
dtype=torch.int64,
device=weight.device).expand(weight.shape)
pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
return pack_index, rank_in_pack
indices = weight.float().sort(-1, descending=True).indices.cpu()
pack_index = torch.full_like(weight,
fill_value=-1,
dtype=torch.int64,
device='cpu')
pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
for i in range(num_layers):
pack_weights = [0] * num_packs
pack_items = [0] * num_packs
for group in indices[i]:
pack = min(
(i
for i in range(num_packs) if pack_items[i] < groups_per_pack),
key=pack_weights.__getitem__)
pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), key=pack_weights.__getitem__)
assert pack_items[pack] < groups_per_pack
pack_index[i, group] = pack
rank_in_pack[i, group] = pack_items[pack]
@@ -56,16 +46,14 @@ def balanced_packing(weight: torch.Tensor,
return pack_index, rank_in_pack
def replicate_experts(
weight: torch.Tensor,
num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def replicate_experts(weight: torch.Tensor, num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
@@ -75,8 +63,7 @@ def replicate_experts(
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64,
device=device).repeat(n, 1)
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
@@ -88,9 +75,9 @@ def replicate_experts(
return phy2log, rank, logcnt
def rebalance_experts_hierarchical(weight: torch.Tensor,
num_physical_experts: int, num_groups: int,
num_nodes: int, num_gpus: int):
def rebalance_experts_hierarchical(
weight: torch.Tensor, num_physical_experts: int, num_groups: int, num_nodes: int, num_gpus: int
):
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
@@ -99,7 +86,7 @@ def rebalance_experts_hierarchical(weight: torch.Tensor,
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
Returns:
physical_to_logical_map: [num_moe_layers, num_physical_experts]
logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
logical_count: [num_moe_layers, num_logical_experts]
@@ -115,45 +102,37 @@ def rebalance_experts_hierarchical(weight: torch.Tensor,
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1, perm,
torch.arange(perm.size(1), dtype=torch.int64,
device=perm.device).expand(perm.shape))
inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape))
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
group_pack_index, group_rank_in_pack = balanced_packing(
tokens_per_group, num_nodes)
log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
group_size).unsqueeze(-1) +
torch.arange(group_size,
dtype=torch.int64,
device=group_pack_index.device)).flatten(-2)
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
log2mlog = (
((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1)
+ torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
).flatten(-2)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes)
phy2mlog, phyrank, mlogcnt = replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes)
tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes)
phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
num_gpus // num_nodes)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device).view(1, -1, 1)).flatten(-2)
pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1)
+ torch.arange(0, num_logical_experts, num_logical_experts // num_nodes, device=group_pack_index.device).view(
1, -1, 1
)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
@@ -161,9 +140,8 @@ def rebalance_experts_hierarchical(weight: torch.Tensor,
def rebalance_experts(
weight: torch.Tensor, num_replicas: int, num_groups: int,
num_nodes: int,
num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, num_gpus: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
@@ -174,7 +152,7 @@ def rebalance_experts(
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
Returns:
physical_to_logical_map: [layers, num_replicas], the expert index of each replica
logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
@@ -183,23 +161,20 @@ def rebalance_experts(
weight = weight.float().cpu()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_gpus)
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, num_groups, num_nodes, num_gpus)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_gpus)
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 1, 1, num_gpus)
maxlogcnt = logcnt.max().item()
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device)
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=torch.int64, device=logcnt.device
)
log2phy.view(num_layers, -1).scatter_(
-1, phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64,
device=log2phy.device).expand(num_layers, -1))
-1,
phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1),
)
return phy2log, log2phy, logcnt
__all__ = ['rebalance_experts']
__all__ = ["rebalance_experts"]

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import json
import os
@@ -21,10 +20,7 @@ def save_matrix_to_json(output_path, file_name, deployment):
layer = {"layer_id": i, "device_count": num_cards}
device_list = []
for j in range(num_cards):
device = {
"device_id": j,
"device_expert": deployment[i, j].tolist()
}
device = {"device_id": j, "device_expert": deployment[i, j].tolist()}
device_list.append(device)
layer["device_list"] = device_list
layer_list.append(layer)
@@ -34,7 +30,7 @@ def save_matrix_to_json(output_path, file_name, deployment):
# Save as JSON file
try:
with open(file_name, 'w') as f:
with open(file_name, "w") as f:
json.dump(data, f, indent=4)
except Exception as e:
print(f"write {file_name} failed: {e}")
@@ -63,21 +59,17 @@ def calculate_average(lst):
return total / count
def layer_imblance_polt(y_list, label_names, device_num, output_path,
file_name):
plt.rcParams['font.sans-serif'] = ['Arial']
plt.rcParams['axes.unicode_minus'] = False
def layer_imblance_polt(y_list, label_names, device_num, output_path, file_name):
plt.rcParams["font.sans-serif"] = ["Arial"]
plt.rcParams["axes.unicode_minus"] = False
x = [i for i in range(58)]
for index, y in enumerate(y_list):
plt.plot(x,
y,
label=rf'{label_names[index]}avg={calculate_average(y)}')
plt.plot(x, y, label=rf"{label_names[index]}avg={calculate_average(y)}")
plt.legend()
plt.title(rf'Load Distribution (num_gpus={device_num})')
plt.xlabel('layer')
plt.ylabel('Device Load')
plt.title(rf"Load Distribution (num_gpus={device_num})")
plt.xlabel("layer")
plt.ylabel("Device Load")
# Show grid lines
plt.grid(True)
@@ -88,27 +80,23 @@ def layer_imblance_polt(y_list, label_names, device_num, output_path,
plt.close()
def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes,
num_gpus, num_original_expert):
def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes, num_gpus, num_original_expert):
from eplb_deepseek import rebalance_experts
num_replicas = num_original_expert + num_redundancy_expert
hy2log, log2phy, logcnt = rebalance_experts(workload, num_replicas,
num_groups, num_nodes,
num_gpus)
hy2log, log2phy, logcnt = rebalance_experts(workload, num_replicas, num_groups, num_nodes, num_gpus)
# Convert to global_deployment
workload = workload.cpu().numpy()
global_deployment = []
layer_num = log2phy.shape[0]
num_physical_experts_local = (num_original_expert +
num_redundancy_expert) // num_gpus
num_physical_experts_local = (num_original_expert + num_redundancy_expert) // num_gpus
for layer_idx in range(layer_num):
layer_deployment = []
for gpu_idx in range(num_gpus):
local_deployment = hy2log[layer_idx][gpu_idx *
num_physical_experts_local:
(gpu_idx + 1) *
num_physical_experts_local]
local_deployment = hy2log[layer_idx][
gpu_idx * num_physical_experts_local : (gpu_idx + 1) * num_physical_experts_local
]
local_deployment = local_deployment.flatten()
layer_deployment.append(local_deployment.tolist())
global_deployment.append(layer_deployment)
@@ -122,18 +110,15 @@ def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes,
new_value = workload[layer_idx].reshape(num_gpus, -1)
row_sum = np.sum(new_value, axis=1)
original_weights.append(row_sum.max())
average_weights.append((np.sum(workload[layer_idx]) / num_gpus))
average_weights.append(np.sum(workload[layer_idx]) / num_gpus)
opt_workload = np.zeros((num_original_expert + num_redundancy_expert),
dtype=np.float64)
opt_workload = np.zeros((num_original_expert + num_redundancy_expert), dtype=np.float64)
for expert_idx in range(num_original_expert):
physical_expert_idxs = log2phy[layer_idx][expert_idx]
physical_expert_idxs = physical_expert_idxs.flatten()
physical_expert_idxs = physical_expert_idxs[
physical_expert_idxs != -1]
physical_expert_idxs = physical_expert_idxs[physical_expert_idxs != -1]
for physical_expert_idx in physical_expert_idxs:
opt_workload[physical_expert_idx] += workload[layer_idx][
expert_idx] / len(physical_expert_idxs)
opt_workload[physical_expert_idx] += workload[layer_idx][expert_idx] / len(physical_expert_idxs)
opt_workload = opt_workload.reshape(num_gpus, -1)
row_sum = np.sum(opt_workload, axis=1)
max_weights.append(row_sum.max())
@@ -142,8 +127,9 @@ def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes,
return global_deployment, y_list
if __name__ == '__main__':
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--exp_name", type=str, default="gsm8k_temp0.0")
parser.add_argument("--num_original_expert", type=int, default=256)
@@ -165,19 +151,13 @@ if __name__ == '__main__':
num_nodes = args.num_nodes
# NOTE: assume input workload format: [layer_num, num_experts]
workload = torch.load(input_path, map_location=torch.device('cpu'))
global_deployment, y_list = deepseek_deploy(workload,
num_redundancy_expert,
num_groups, num_nodes,
num_devices,
num_original_expert)
workload = torch.load(input_path, map_location=torch.device("cpu"))
global_deployment, y_list = deepseek_deploy(
workload, num_redundancy_expert, num_groups, num_nodes, num_devices, num_original_expert
)
file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}"
save_matrix_to_json(output_path, file_name, np.array(global_deployment))
label_names = [
'default deployment max load', 'balanced load max load',
'balanced load avg load'
]
label_names = ["default deployment max load", "balanced load max load", "balanced load avg load"]
new_file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}.png"
layer_imblance_polt(y_list, label_names, num_devices, output_path,
new_file_name)
layer_imblance_polt(y_list, label_names, num_devices, output_path, new_file_name)


@@ -18,6 +18,7 @@ All the arguments that can be set by users are:
7. `--vllm-start-port`: Starting port of vLLM serving instances, default 9000
An example of running external DP in one single node:
```(python)
cd examples/external_online_dp
# running DP4 TP4 in a node with 16 NPUs
@@ -25,6 +26,7 @@ python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 4 --dp-rank-s
```
An example of running external DP in two nodes:
```(python)
cd examples/external_online_dp
# running DP4 TP4 in two nodes with 8 NPUs each
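# (Illustrative sketch, assuming each of the two nodes hosts dp-size / 2 = 2 engine
#  instances: node 0 would pass --dp-size-local 2 --dp-rank-start 0 and node 1 would
#  pass --dp-size-local 2 --dp-rank-start 2, both with --dp-address set to node 0's
#  IP, for example:
#  python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 2 \
#      --dp-rank-start 2 --dp-address <node0-ip> --vllm-start-port 9000
#  where <node0-ip> is a placeholder.)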


@@ -23,7 +23,7 @@
# ----------------------------------
# You need to have at least two vLLM servers running in data parallel.
# These can be mock servers or actual vLLM servers.
# Note that this proxy also works with only one vLLM server running, but
# Note that this proxy also works with only one vLLM server running, but
# will fall back to direct request forwarding which is meaningless.
#
# For testing, you can use the provided mock server:
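# (Illustrative only: with two backend servers running on ports 8001 and 8002, the
#  proxy could be started with the flags defined in parse_args below, for example
#  python <this_proxy_script>.py --host 0.0.0.0 --port 8000 \
#      --dp-hosts localhost localhost --dp-ports 8001 8002
#  where the script name and the backend ports are placeholders.)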
@@ -84,13 +84,12 @@ import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List
from typing import Any
import httpx
from fastapi import FastAPI, Request
@@ -109,34 +108,29 @@ except ImportError:
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.url = f"http://{host}:{port}/v1"
self.client = httpx.AsyncClient(
timeout=None,
base_url=self.url,
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
)
self.active_tokens = 0
self.aborted_requests = set() # Track aborted requests
class ProxyState:
def __init__(self, server_instances):
self.dp_servers: List[ServerState] = [
ServerState(h, p) for h, p in server_instances
]
self.dp_servers: list[ServerState] = [ServerState(h, p) for h, p in server_instances]
self.req_id_lock = asyncio.Lock()
# Removed selection locks - no longer needed for synchronous methods
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.lb_heap = [(0, i, server)
for i, server in enumerate(self.dp_servers)]
self.lb_heap = [(0, i, server) for i, server in enumerate(self.dp_servers)]
heapq.heapify(self.lb_heap)
def _update_server_priority(self, server_idx: int):
@@ -144,10 +138,8 @@ class ProxyState:
server = self.dp_servers[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.lb_heap = [(p, i, s) for p, i, s in self.lb_heap
if i != server_idx]
heapq.heappush(self.lb_heap,
(priority, server_idx, server)) # type: ignore
self.lb_heap = [(p, i, s) for p, i, s in self.lb_heap if i != server_idx]
heapq.heappush(self.lb_heap, (priority, server_idx, server)) # type: ignore
async def next_req_id(self):
async with self.req_id_lock:
@@ -190,27 +182,15 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--dp-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--dp-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument("--dp-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--dp-ports", type=int, nargs="+", default=[8001])
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
)
args = parser.parse_args()
if len(args.dp_hosts) != len(args.dp_ports):
raise ValueError(
"Number of dp hosts must match number of dp ports")
raise ValueError("Number of dp hosts must match number of dp ports")
args.server_instances = list(zip(args.dp_hosts, args.dp_ports))
return args
@@ -219,9 +199,7 @@ def parse_args():
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.server_instances)
print(
f"Initialized {len(proxy_state.dp_servers)} dp server clients."
)
print(f"Initialized {len(proxy_state.dp_servers)} dp server clients.")
yield
for p in proxy_state.dp_servers:
await p.client.aclose()
@@ -236,14 +214,12 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
@@ -256,22 +232,18 @@ def with_cancellation(handler_func):
app = FastAPI(lifespan=lifespan)
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async def stream_service_response_with_retry(
client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
@@ -280,53 +252,42 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
if "first_chunk_sent" in locals() and first_chunk_sent:
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
async def _select_instance(api: str, req_data: Any,
request_length: int):
async def _select_instance(api: str, req_data: Any, request_length: int):
# refer to vLLM sampling_params: max_token default value
max_tokens = req_data.get("max_tokens", 16)
ignore_eos = req_data.get("ignore_eos", False)
priority_score = proxy_state.calculate_request_score(request_length,max_tokens=max_tokens, ignore_eos=ignore_eos)
priority_score = proxy_state.calculate_request_score(request_length, max_tokens=max_tokens, ignore_eos=ignore_eos)
logger.debug(
f"Request length: {request_length}, max tokens: {max_tokens}, ignore_eos: {ignore_eos}, Priority score: {priority_score}"
f"Request length: {request_length}, max tokens: {max_tokens}, "
f"ignore_eos: {ignore_eos}, Priority score: {priority_score}"
)
request_id = await proxy_state.next_req_id()
# Select dp server based on priority score
server_idx = proxy_state.select_server(priority_score)
choosen_server = proxy_state.dp_servers[server_idx]
logger.debug(f"Choose server {choosen_server.url} to process request {request_id}")
return InstanceInfo(request_id=request_id,
server_idx=server_idx,
priority_score=priority_score,
server_state=choosen_server)
return InstanceInfo(
request_id=request_id, server_idx=server_idx, priority_score=priority_score, server_state=choosen_server
)
@dataclass
@@ -342,36 +303,36 @@ async def _handle_completions(api: str, request: Request):
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
instance_info = await _select_instance(api, req_data,
request_length)
instance_info = await _select_instance(api, req_data, request_length)
async def generate_stream():
nonlocal instance_info
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
instance_info.server_state.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
instance_info.server_state.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from server {instance_info.server_state.url}: {str(e)}, the aborted request is: {instance_info.request_id}."
f"Error during streaming from server {instance_info.server_state.url}: {str(e)}, "
f"the aborted request is: {instance_info.request_id}."
)
# After streaming done, release tokens
proxy_state.release_server(instance_info.server_idx,
instance_info.priority_score)
proxy_state.release_server(instance_info.server_idx, instance_info.priority_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in external dp proxy server"
f" - {api} endpoint")
print(f"Error occurred in external dp proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -397,7 +358,7 @@ async def healthcheck():
}
if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn


@@ -4,52 +4,19 @@ import os
import subprocess
import sys
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dp-size",
type=int,
required=True,
help="Data parallel size."
)
parser.add_argument(
"--tp-size",
type=int,
default=1,
help="Tensor parallel size."
)
parser.add_argument(
"--dp-size-local",
type=int,
default=-1,
help="Local data parallel size."
)
parser.add_argument(
"--dp-rank-start",
type=int,
default=0,
help="Starting rank for data parallel."
)
parser.add_argument(
"--dp-address",
type=str,
required=True,
help="IP address for data parallel master node."
)
parser.add_argument(
"--dp-rpc-port",
type=str,
default=12345,
help="Port for data parallel master node."
)
parser.add_argument(
"--vllm-start-port",
type=int,
default=9000,
help="Starting port for the engine."
)
parser.add_argument("--dp-size", type=int, required=True, help="Data parallel size.")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size.")
parser.add_argument("--dp-size-local", type=int, default=-1, help="Local data parallel size.")
parser.add_argument("--dp-rank-start", type=int, default=0, help="Starting rank for data parallel.")
parser.add_argument("--dp-address", type=str, required=True, help="IP address for data parallel master node.")
parser.add_argument("--dp-rpc-port", type=str, default=12345, help="Port for data parallel master node.")
parser.add_argument("--vllm-start-port", type=int, default=9000, help="Starting port for the engine.")
return parser.parse_args()
args = parse_args()
dp_size = args.dp_size
tp_size = args.tp_size
@@ -61,6 +28,7 @@ dp_address = args.dp_address
dp_rpc_port = args.dp_rpc_port
vllm_start_port = args.vllm_start_port
def run_command(visiable_devices, dp_rank, vllm_engine_port):
command = [
"bash",
@@ -75,6 +43,7 @@ def run_command(visiable_devices, dp_rank, vllm_engine_port):
]
subprocess.run(command, check=True)
if __name__ == "__main__":
template_path = "./run_dp_template.sh"
if not os.path.exists(template_path):
@@ -87,11 +56,9 @@ if __name__ == "__main__":
dp_rank = dp_rank_start + i
vllm_engine_port = vllm_start_port + i
visiable_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
process = multiprocessing.Process(target=run_command,
args=(visiable_devices, dp_rank,
vllm_engine_port))
process = multiprocessing.Process(target=run_command, args=(visiable_devices, dp_rank, vllm_engine_port))
processes.append(process)
process.start()
for process in processes:
process.join()
process.join()


@@ -61,13 +61,13 @@ from time import sleep
import torch
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel)
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel # noqa E402
from vllm.utils.network_utils import get_open_port
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def parse_args():
import argparse
@@ -78,43 +78,18 @@ def parse_args():
default="ibm-research/PowerMoE-3b",
help="Model name or path",
)
parser.add_argument("--dp-size",
type=int,
default=2,
help="Data parallel size")
parser.add_argument("--tp-size",
type=int,
default=1,
help="Tensor parallel size")
parser.add_argument("--node-size",
type=int,
default=1,
help="Total number of nodes")
parser.add_argument("--node-rank",
type=int,
default=0,
help="Rank of the current node")
parser.add_argument("--master-addr",
type=str,
default="",
help="Master node IP address")
parser.add_argument("--master-port",
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action="store_true",
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action="store_true",
help="Trust remote code.")
parser.add_argument("--enable-expert-parallel",
action="store_true",
help="Enable expert parallel, used in MOE models.")
parser.add_argument("--quantization",
type=str,
default="",
help="Use quantization models")
parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--node-size", type=int, default=1, help="Total number of nodes")
parser.add_argument("--node-rank", type=int, default=0, help="Rank of the current node")
parser.add_argument("--master-addr", type=str, default="", help="Master node IP address")
parser.add_argument("--master-port", type=int, default=0, help="Master node port")
parser.add_argument("--enforce-eager", action="store_true", help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code", action="store_true", help="Trust remote code.")
parser.add_argument(
"--enable-expert-parallel", action="store_true", help="Enable expert parallel, used in MOE models."
)
parser.add_argument("--quantization", type=str, default="", help="Use quantization models")
return parser.parse_args()
@@ -127,6 +102,7 @@ def cleanup_env_and_memory():
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
def main(
model,
dp_size,
@@ -168,7 +144,7 @@ def main(
def start(rank):
return rank * floor + min(rank, remainder)
prompts = prompts[start(global_dp_rank):start(global_dp_rank + 1)]
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
if len(prompts) == 0:
# if any rank has no prompts to process,
# we need to set a placeholder prompt
@@ -179,9 +155,7 @@ def main(
# since we are doing data parallel, every rank can have different
# sampling params. here we set different max_tokens for different
# ranks for demonstration.
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=[16, 20][global_dp_rank % 2])
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2])
# Create an LLM.
llm = LLM(
@@ -200,14 +174,14 @@ def main(
break
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Give engines time to pause their processing loops before exiting.
sleep(5)
del llm
cleanup_env_and_memory()
if __name__ == "__main__":
args = parse_args()
@@ -231,8 +205,7 @@ if __name__ == "__main__":
from multiprocessing import Process
procs = []
for local_dp_rank, global_dp_rank in enumerate(
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
for local_dp_rank, global_dp_rank in enumerate(range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
proc = Process(
target=main,
args=(
@@ -255,9 +228,7 @@ if __name__ == "__main__":
for proc in procs:
proc.join(timeout=900)
if proc.exitcode is None:
print(
f"Killing process {proc.pid} that didn't stop within 15 minutes."
)
print(f"Killing process {proc.pid} that didn't stop within 15 minutes.")
proc.kill()
exit_code = 1
elif proc.exitcode:


@@ -29,8 +29,8 @@ def clean_up():
import gc
import torch
from vllm.distributed.parallel_state import (
destroy_distributed_environment, destroy_model_parallel)
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
destroy_model_parallel()
destroy_distributed_environment()
gc.collect()
@@ -44,8 +44,10 @@ def run_prefill(prefill_done, process_close):
from vllm.config import KVTransferConfig
prompts = [
"Hello, how are you today?", "Hi, what is your name?",
"Tell me a very long story.", "what is your favourite book?"
"Hello, how are you today?",
"Hi, what is your name?",
"Tell me a very long story.",
"what is your favourite book?",
]
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
@@ -55,22 +57,16 @@ def run_prefill(prefill_done, process_close):
kv_port="30000",
engine_id="0",
kv_connector_module_path="vllm_ascend.distributed.mooncake_connector",
kv_connector_extra_config={
"prefill": {
"dp_size": 1,
"tp_size": 1
},
"decode": {
"dp_size": 1,
"tp_size": 1
}
})
kv_connector_extra_config={"prefill": {"dp_size": 1, "tp_size": 1}, "decode": {"dp_size": 1, "tp_size": 1}},
)
# Set NPU memory utilization to 0.8
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
kv_transfer_config=ktc,
max_model_len=2000,
gpu_memory_utilization=0.8,
tensor_parallel_size=1)
llm = LLM(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
kv_transfer_config=ktc,
max_model_len=2000,
gpu_memory_utilization=0.8,
tensor_parallel_size=1,
)
llm.generate(prompts, sampling_params)
print("Prefill node is finished.")
@@ -96,8 +92,10 @@ def run_decode(prefill_done):
from vllm.config import KVTransferConfig
prompts = [
"Hello, how are you today?", "Hi, what is your name?",
"Tell me a very long story.", "what is your favourite book?"
"Hello, how are you today?",
"Hi, what is your name?",
"Tell me a very long story.",
"what is your favourite book?",
]
sampling_params = SamplingParams(temperature=0, top_p=0.95)
@@ -107,22 +105,16 @@ def run_decode(prefill_done):
kv_port="30100",
engine_id="1",
kv_connector_module_path="vllm_ascend.distributed.mooncake_connector",
kv_connector_extra_config={
"prefill": {
"dp_size": 1,
"tp_size": 1
},
"decode": {
"dp_size": 1,
"tp_size": 1
}
})
kv_connector_extra_config={"prefill": {"dp_size": 1, "tp_size": 1}, "decode": {"dp_size": 1, "tp_size": 1}},
)
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
kv_transfer_config=ktc,
max_model_len=2000,
gpu_memory_utilization=0.8,
tensor_parallel_size=1)
llm = LLM(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
kv_transfer_config=ktc,
max_model_len=2000,
gpu_memory_utilization=0.8,
tensor_parallel_size=1,
)
# Wait for the producer to start the consumer
print("Waiting for prefill node to finish...")
@@ -141,16 +133,18 @@ def run_decode(prefill_done):
if __name__ == "__main__":
mp.get_context('spawn')
mp.get_context("spawn")
prefill_done = Event()
process_close = Event()
prefill_process = Process(target=run_prefill,
args=(
prefill_done,
process_close,
))
decode_process = Process(target=run_decode, args=(prefill_done, ))
prefill_process = Process(
target=run_prefill,
args=(
prefill_done,
process_close,
),
)
decode_process = Process(target=run_decode, args=(prefill_done,))
# Start prefill node
prefill_process.start()


@@ -25,22 +25,24 @@ from vllm import LLM
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def get_detailed_instruct(task_description: str, query: str) -> str:
return f'Instruct: {task_description}\nQuery:{query}'
return f"Instruct: {task_description}\nQuery:{query}"
def main():
# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'
task = "Given a web search query, retrieve relevant passages that answer the query"
queries = [
get_detailed_instruct(task, 'What is the capital of China?'),
get_detailed_instruct(task, 'Explain gravity')
get_detailed_instruct(task, "What is the capital of China?"),
get_detailed_instruct(task, "Explain gravity"),
]
# No need to add instruction for retrieval documents
documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
"Gravity is a force that attracts two bodies towards each other. "
"It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]
input_texts = queries + documents
@@ -49,7 +51,7 @@ def main():
outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
# Calculate the similarity scores between the first two queries and the last two documents
scores = (embeddings[:2] @ embeddings[2:].T)
scores = embeddings[:2] @ embeddings[2:].T
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]


@@ -63,10 +63,13 @@ from multiprocessing import Process
from time import sleep
import torch
from safetensors.torch import load_file
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
from safetensors.torch import load_file
destroy_distributed_environment,
destroy_model_parallel,
get_tp_group,
)
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_open_port
@@ -101,7 +104,6 @@ def load_and_merge_safetensors(directory):
def parse_args():
parser = argparse.ArgumentParser(description="External launcher Inference")
parser.add_argument(
"--model",
@@ -109,60 +111,41 @@ def parse_args():
default="Qwen/Qwen3-0.6B",
help="Model name or path",
)
parser.add_argument("--tp-size",
type=int,
default=1,
help="Tensor parallel size")
parser.add_argument("--node-size",
type=int,
default=1,
help="Total number of nodes")
parser.add_argument("--node-rank",
type=int,
default=0,
help="Rank of the current node")
parser.add_argument("--proc-per-node",
type=int,
default=1,
help="Number of processes per node")
parser.add_argument("--master-addr",
type=str,
default="",
help="Master node IP address")
parser.add_argument("--master-port",
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action="store_true",
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action="store_true",
help="Trust remote code.")
parser.add_argument("--enable-expert-parallel",
action="store_true",
help="Enable expert parallel, used in MOE models.")
parser.add_argument("--enable-sleep-mode",
action="store_true",
help="Enable sleep mode for the engine.")
parser.add_argument("--temperature",
type=float,
default=0.8,
help="Float that controls the randomness of the sampling.")
parser.add_argument("--model-weight-gib",
type=float,
default=None,
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
parser.add_argument("--sleep-mode-level",
type=int,
choices=[1, 2],
default=1,
help="Sleep mode level: 1 or 2. This example of level 2 is only supported for dense model.")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--node-size", type=int, default=1, help="Total number of nodes")
parser.add_argument("--node-rank", type=int, default=0, help="Rank of the current node")
parser.add_argument("--proc-per-node", type=int, default=1, help="Number of processes per node")
parser.add_argument("--master-addr", type=str, default="", help="Master node IP address")
parser.add_argument("--master-port", type=int, default=0, help="Master node port")
parser.add_argument("--enforce-eager", action="store_true", help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code", action="store_true", help="Trust remote code.")
parser.add_argument(
"--enable-expert-parallel", action="store_true", help="Enable expert parallel, used in MOE models."
)
parser.add_argument("--enable-sleep-mode", action="store_true", help="Enable sleep mode for the engine.")
parser.add_argument(
"--temperature", type=float, default=0.8, help="Float that controls the randomness of the sampling."
)
parser.add_argument(
"--model-weight-gib",
type=float,
default=None,
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).",
)
parser.add_argument(
"--sleep-mode-level",
type=int,
choices=[1, 2],
default=1,
help="Sleep mode level: 1 or 2. This example of level 2 is only supported for dense model.",
)
args = parser.parse_args()
if args.enable_sleep_mode:
if args.model_weight_gib is None or args.temperature != 0:
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
parser.error(
"model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set."
)
if args.model_weight_gib <= 0:
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
if args.model == parser.get_default("model") and args.model_weight_gib is None:
@@ -220,7 +203,7 @@ def main(
enable_sleep_mode=enable_sleep_mode,
)
tp_ranks = get_tp_group().ranks
print(f'TP RANKS: {tp_ranks}')
print(f"TP RANKS: {tp_ranks}")
outputs = llm.generate(prompts, sampling_params)
@@ -231,7 +214,7 @@ def main(
if rank == 0:
free_bytes_after_sleep, total = torch.npu.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
print(f"Freed memory: {freed_bytes / 1024**3:.2f} GiB")
# now the freed memory should be larger than the model weights
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
@@ -257,8 +240,7 @@ def main(
break
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Global rank: {rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
print(f"Global rank: {rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Give engines time to pause their processing loops before exiting.
sleep(5)
@@ -294,25 +276,26 @@ if __name__ == "__main__":
world_size = node_size * proc_per_node
procs = []
for local_rank, rank in enumerate(
range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
proc = Process(target=main,
args=(
local_rank,
rank,
master_addr,
master_port,
args.model_weight_gib,
args.model,
world_size,
tp_size,
args.enable_expert_parallel,
args.enforce_eager,
args.trust_remote_code,
args.enable_sleep_mode,
args.temperature,
args.sleep_mode_level,
))
for local_rank, rank in enumerate(range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
proc = Process(
target=main,
args=(
local_rank,
rank,
master_addr,
master_port,
args.model_weight_gib,
args.model,
world_size,
tp_size,
args.enable_expert_parallel,
args.enforce_eager,
args.trust_remote_code,
args.enable_sleep_mode,
args.temperature,
args.sleep_mode_level,
),
)
proc.start()
procs.append(proc)
@@ -320,9 +303,7 @@ if __name__ == "__main__":
for proc in procs:
proc.join(timeout=600)
if proc.exitcode is None:
print(
f"Killing process {proc.pid} that didn't stop within 30 minutes."
)
print(f"Killing process {proc.pid} that didn't stop within 30 minutes.")
proc.kill()
exit_code = 1
elif proc.exitcode:

View File

@@ -17,19 +17,20 @@
# Adapted from vllm-project/vllm/examples/offline_inference/audio_language.py
#
"""
This example shows how to use vLLM for running offline inference
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
import argparse
import os
from vllm.assets.audio import AudioAsset
try:
import librosa # type: ignore
import librosa # type: ignore
except ImportError:
raise Exception("Can't import librosa, please ensure it's installed")
@@ -40,7 +41,7 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def prepare_inputs(audio_count: int, audio_path1: str, audio_path2: str):
use_vllm_audio_assert = True if audio_path1 == "mary_had_lamb" and audio_path2 == "winning_call" else False
use_vllm_audio_assert = audio_path1 == "mary_had_lamb" and audio_path2 == "winning_call"
if use_vllm_audio_assert:
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
else:
@@ -48,22 +49,22 @@ def prepare_inputs(audio_count: int, audio_path1: str, audio_path2: str):
question_per_audio_count = {
1: "What is recited in the audio?",
2: "What sport and what nursery rhyme are referenced?"
2: "What sport and what nursery rhyme are referenced?",
}
audio_in_prompt = "".join([
f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
for idx in range(audio_count)
])
audio_in_prompt = "".join([f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)])
question = question_per_audio_count[audio_count]
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
prompt = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
mm_data = {
"audio":
audio_assets if not use_vllm_audio_assert else [asset.audio_and_sample_rate for asset in audio_assets[:audio_count]]
"audio": audio_assets
if not use_vllm_audio_assert
else [asset.audio_and_sample_rate for asset in audio_assets[:audio_count]]
}
# Merge text prompt and audio data into inputs
@@ -76,17 +77,17 @@ def main(audio_count: int, audio_path1: str, audio_path2: str):
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
# `limit_mm_per_prompt`: the max num items for each modality per prompt.
llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct",
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True)
llm = LLM(
model="Qwen/Qwen2-Audio-7B-Instruct",
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
enforce_eager=True,
)
inputs = prepare_inputs(audio_count, audio_path1, audio_path2)
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=None)
sampling_params = SamplingParams(temperature=0.2, max_tokens=64, stop_token_ids=None)
outputs = llm.generate(inputs, sampling_params=sampling_params)
@@ -96,7 +97,9 @@ def main(audio_count: int, audio_path1: str, audio_path2: str):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Arguments of rank table generator", )
parser = argparse.ArgumentParser(
description="Arguments of rank table generator",
)
parser.add_argument("--audio-path1", type=str, default="mary_had_lamb")
parser.add_argument("--audio-path2", type=str, default="winning_call")
args = parser.parse_args()


@@ -1,6 +1,6 @@
import argparse
import os
import time
import argparse
from vllm import LLM, SamplingParams
@@ -11,14 +11,14 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input_len', type=int, default=1024)
parser.add_argument('--output_len', type=int, default=128)
parser.add_argument('--bs', type=int, default=1)
parser.add_argument('--model_path', type=str, default="deepseek-ai/DeepSeek-V2-Lite")
parser.add_argument('--tp', type=int, default=2)
parser.add_argument('--pcp', type=int, default=2)
parser.add_argument('--dcp', type=int, default=1)
parser.add_argument('--iter_times', type=int, default=1)
parser.add_argument("--input_len", type=int, default=1024)
parser.add_argument("--output_len", type=int, default=128)
parser.add_argument("--bs", type=int, default=1)
parser.add_argument("--model_path", type=str, default="deepseek-ai/DeepSeek-V2-Lite")
parser.add_argument("--tp", type=int, default=2)
parser.add_argument("--pcp", type=int, default=2)
parser.add_argument("--dcp", type=int, default=1)
parser.add_argument("--iter_times", type=int, default=1)
args = parser.parse_args()
@@ -26,10 +26,10 @@ if __name__ == "__main__":
"The capital of France is",
"Hello, my name is Tom, I am",
"The president of United States is",
"AI future is"
"AI future is",
]
sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens=args.output_len)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=args.output_len)
llm = LLM(
model=args.model_path,
trust_remote_code=True,
@@ -44,7 +44,7 @@ if __name__ == "__main__":
max_model_len=1024,
max_num_seqs=1,
block_size=128,
gpu_memory_utilization=0.9
gpu_memory_utilization=0.9,
)
t0 = time.time()
@@ -56,4 +56,4 @@ if __name__ == "__main__":
for i, output in enumerate(outputs):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"req_num: {i}\nGenerated text: {generated_text!r}")
print(f"req_num: {i}\nGenerated text: {generated_text!r}")


@@ -37,11 +37,13 @@ def main():
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
# Create an LLM.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
tensor_parallel_size=2,
enforce_eager=True,
trust_remote_code=True,
max_model_len=1024)
llm = LLM(
model="deepseek-ai/DeepSeek-V2-Lite",
tensor_parallel_size=2,
enforce_eager=True,
trust_remote_code=True,
max_model_len=1024,
)
# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)


@@ -25,11 +25,12 @@ from vllm.utils.mem_constants import GiB_bytes
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def main():
prompt = "How are you?"
free, total = torch.npu.mem_get_info()
print(f"Free memory before sleep: {free / 1024 ** 3:.2f} GiB")
print(f"Free memory before sleep: {free / 1024**3:.2f} GiB")
# record npu memory use baseline in case other process is running
used_bytes_baseline = total - free
llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
@@ -39,9 +40,7 @@ def main():
llm.sleep(level=1)
free_npu_bytes_after_sleep, total = torch.npu.mem_get_info()
print(
f"Free memory after sleep: {free_npu_bytes_after_sleep / 1024 ** 3:.2f} GiB"
)
print(f"Free memory after sleep: {free_npu_bytes_after_sleep / 1024**3:.2f} GiB")
used_bytes = total - free_npu_bytes_after_sleep - used_bytes_baseline
# now the memory usage should be less than the model weights
# (0.5B model, 1GiB weights)


@@ -63,19 +63,21 @@ from multiprocessing import Process
from time import sleep
import torch
from safetensors.torch import load_file
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
from safetensors.torch import load_file
destroy_distributed_environment,
destroy_model_parallel,
get_tp_group,
)
from vllm.model_executor.model_loader.utils import process_weights_after_loading
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_open_port
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def patch_vllm_moe_model_weight_loader(model):
# Define MLP attribute mapping for different model types
@@ -92,24 +94,25 @@ def patch_vllm_moe_model_weight_loader(model):
if "w13_weight" in name or "w2_weight" in name:
param.weight_loader = mlp.experts.weight_loader
def load_and_merge_safetensors(directory):
merged_dict = {}
if not os.path.isdir(directory):
raise ValueError(f"directory is not exist : {directory}")
for filename in os.listdir(directory):
if filename.endswith('.safetensors'):
if filename.endswith(".safetensors"):
file_path = os.path.join(directory, filename)
print(f"loading file: {file_path}")
f = load_file(file_path)
merged_dict.update(f)
return merged_dict
def parse_args():
def parse_args():
parser = argparse.ArgumentParser(description="External launcher Inference")
parser.add_argument(
"--model",
@@ -117,55 +120,34 @@ def parse_args():
default="Qwen/Qwen3-0.6B",
help="Model name or path",
)
parser.add_argument("--tp-size",
type=int,
default=1,
help="Tensor parallel size")
parser.add_argument("--node-size",
type=int,
default=1,
help="Total number of nodes")
parser.add_argument("--node-rank",
type=int,
default=0,
help="Rank of the current node")
parser.add_argument("--proc-per-node",
type=int,
default=1,
help="Number of processes per node")
parser.add_argument("--master-addr",
type=str,
default="",
help="Master node IP address")
parser.add_argument("--master-port",
type=int,
default=0,
help="Master node port")
parser.add_argument("--enforce-eager",
action="store_true",
help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code",
action="store_true",
help="Trust remote code.")
parser.add_argument("--enable-expert-parallel",
action="store_true",
help="Enable expert parallel, used in MOE models.")
parser.add_argument("--enable-sleep-mode",
action="store_true",
help="Enable sleep mode for the engine.")
parser.add_argument("--temperature",
type=float,
default=0.8,
help="Float that controls the randomness of the sampling.")
parser.add_argument("--model-weight-gib",
type=float,
default=None,
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--node-size", type=int, default=1, help="Total number of nodes")
parser.add_argument("--node-rank", type=int, default=0, help="Rank of the current node")
parser.add_argument("--proc-per-node", type=int, default=1, help="Number of processes per node")
parser.add_argument("--master-addr", type=str, default="", help="Master node IP address")
parser.add_argument("--master-port", type=int, default=0, help="Master node port")
parser.add_argument("--enforce-eager", action="store_true", help="Enforce eager mode execution.")
parser.add_argument("--trust-remote-code", action="store_true", help="Trust remote code.")
parser.add_argument(
"--enable-expert-parallel", action="store_true", help="Enable expert parallel, used in MOE models."
)
parser.add_argument("--enable-sleep-mode", action="store_true", help="Enable sleep mode for the engine.")
parser.add_argument(
"--temperature", type=float, default=0.8, help="Float that controls the randomness of the sampling."
)
parser.add_argument(
"--model-weight-gib",
type=float,
default=None,
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).",
)
args = parser.parse_args()
if args.enable_sleep_mode:
if args.model_weight_gib is None or args.temperature != 0:
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
parser.error(
"model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set."
)
if args.model_weight_gib <= 0:
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
if args.model == parser.get_default("model") and args.model_weight_gib is None:
@@ -219,7 +201,7 @@ def main(
trust_remote_code=trust_remote_code,
distributed_executor_backend="external_launcher",
seed=0,
gpu_memory_utilization = 0.95,
gpu_memory_utilization=0.95,
enable_sleep_mode=enable_sleep_mode,
)
outputs = llm.generate(prompts, sampling_params)
@@ -231,7 +213,7 @@ def main(
if rank == 0:
free_bytes_after_sleep, total = torch.npu.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
print(f"Freed memory: {freed_bytes / 1024**3:.2f} GiB")
# now the freed memory should be larger than the model weights
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
@@ -242,9 +224,9 @@ def main(
patch_vllm_moe_model_weight_loader(runmodel)
sd = load_and_merge_safetensors(model_path)
runmodel.load_weights(sd.items())
print('load state dict done')
print("load state dict done")
tp_ranks = get_tp_group().ranks
print(f'TP RANKS: {tp_ranks}')
print(f"TP RANKS: {tp_ranks}")
vllm_config = llm.llm_engine.vllm_config.model_config
device = next(runmodel.parameters()).device
@@ -262,8 +244,7 @@ def main(
break
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Global rank: {rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
print(f"Global rank: {rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
# Give engines time to pause their processing loops before exiting.
sleep(5)
@@ -299,24 +280,25 @@ if __name__ == "__main__":
world_size = node_size * proc_per_node
procs = []
for local_rank, rank in enumerate(
range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
proc = Process(target=main,
args=(
local_rank,
rank,
master_addr,
master_port,
args.model_weight_gib,
args.model,
world_size,
tp_size,
args.enable_expert_parallel,
args.enforce_eager,
args.trust_remote_code,
args.enable_sleep_mode,
args.temperature,
))
for local_rank, rank in enumerate(range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
proc = Process(
target=main,
args=(
local_rank,
rank,
master_addr,
master_port,
args.model_weight_gib,
args.model,
world_size,
tp_size,
args.enable_expert_parallel,
args.enforce_eager,
args.trust_remote_code,
args.enable_sleep_mode,
args.temperature,
),
)
proc.start()
procs.append(proc)
@@ -324,9 +306,7 @@ if __name__ == "__main__":
for proc in procs:
proc.join(timeout=600)
if proc.exitcode is None:
print(
f"Killing process {proc.pid} that didn't stop within 30 minutes."
)
print(f"Killing process {proc.pid} that didn't stop within 30 minutes.")
proc.kill()
exit_code = 1
elif proc.exitcode:


@@ -20,7 +20,6 @@ Run:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from vllm import LLM
@@ -37,16 +36,12 @@ def get_prompt_embeds(
tokenizer: PreTrainedTokenizer,
embedding_layer: torch.nn.Module,
):
token_ids = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, return_tensors="pt"
)
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
prompt_embeds = embedding_layer(token_ids).squeeze(0)
return prompt_embeds
def single_prompt_inference(
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
):
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
@@ -63,18 +58,14 @@ def single_prompt_inference(
print("-" * 30)
def batch_prompt_inference(
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
):
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
chats = [
[{"role": "user", "content": "Please tell me about the capital of France."}],
[{"role": "user", "content": "When is the day longest during the year?"}],
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
]
prompt_embeds_list = [
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
]
prompt_embeds_list = [get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats]
outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])


@@ -1,8 +1,7 @@
import os
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizer)
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from vllm import LLM
os.environ["VLLM_USE_MODELSCOPE"] = "True"
@@ -17,27 +16,21 @@ def init_tokenizer_and_llm(model_name: str):
return tokenizer, embedding_layer, llm
def get_prompt_embeds(chat: list[dict[str,
str]], tokenizer: PreTrainedTokenizer,
embedding_layer: torch.nn.Module):
token_ids = tokenizer.apply_chat_template(chat,
add_generation_prompt=True,
return_tensors='pt')
def get_prompt_embeds(chat: list[dict[str, str]], tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
prompt_embeds = embedding_layer(token_ids).squeeze(0)
return prompt_embeds
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
embedding_layer: torch.nn.Module):
chat = [{
"role": "user",
"content": "Please tell me about the capital of France."
}]
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
outputs = llm.generate({
"prompt_embeds": prompt_embeds,
})
outputs = llm.generate(
{
"prompt_embeds": prompt_embeds,
}
)
print("\n[Single Inference Output]")
print("-" * 30)
@@ -46,34 +39,22 @@ def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
print("-" * 30)
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
embedding_layer: torch.nn.Module):
chats = [[{
"role": "user",
"content": "Please tell me about the capital of France."
}],
[{
"role": "user",
"content": "When is the day longest during the year?"
}],
[{
"role": "user",
"content": "Where is bigger, the moon or the sun?"
}]]
prompt_embeds_list = [
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
chats = [
[{"role": "user", "content": "Please tell me about the capital of France."}],
[{"role": "user", "content": "When is the day longest during the year?"}],
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
]
outputs = llm.generate([{
"prompt_embeds": embeds
} for embeds in prompt_embeds_list])
prompt_embeds_list = [get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats]
outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
print("\n[Batch Inference Outputs]")
print("-" * 30)
for i, o in enumerate(outputs):
print(f"Q{i+1}: {chats[i][0]['content']}")
print(f"A{i+1}: {o.outputs[0].text}\n")
print(f"Q{i + 1}: {chats[i][0]['content']}")
print(f"A{i + 1}: {o.outputs[0].text}\n")
print("-" * 30)


@@ -1,31 +1,25 @@
import os
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme, QuantizationStrategy, QuantizationType
from datasets import load_dataset
from transformers import AutoModelForCausalLM, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, \
AutoTokenizer, AutoProcessor, AutoConfig, AutoImageProcessor
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme, QuantizationType, QuantizationStrategy
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
)
W8A8_W_cha_A_ten_static_symmetric = {
"group_0": QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(
num_bits=8,
type=QuantizationType.INT,
strategy=QuantizationStrategy.CHANNEL,
symmetric=True,
dynamic=False
num_bits=8, type=QuantizationType.INT, strategy=QuantizationStrategy.CHANNEL, symmetric=True, dynamic=False
),
input_activations=QuantizationArgs(
num_bits=8,
type=QuantizationType.INT,
strategy=QuantizationStrategy.TENSOR,
symmetric=True,
dynamic=False
num_bits=8, type=QuantizationType.INT, strategy=QuantizationStrategy.TENSOR, symmetric=True, dynamic=False
),
),
}
@@ -53,19 +47,19 @@ TOKENIZER_DICT = {
def load_environment_variables():
env_vars = {
'model_path': "Qwen/Qwen3-32B",
'export_path': "/llm-compressor/export/GPTQ/W8A8_W_cha_A_ten_static_symmetric",
'modifier': "GPTQ",
'schemes': "W8A8_W_cha_A_ten_static_symmetric",
'calib_prompt_path': "HuggingFaceH4/ultrachat_200k"
"model_path": "Qwen/Qwen3-32B",
"export_path": "/llm-compressor/export/GPTQ/W8A8_W_cha_A_ten_static_symmetric",
"modifier": "GPTQ",
"schemes": "W8A8_W_cha_A_ten_static_symmetric",
"calib_prompt_path": "HuggingFaceH4/ultrachat_200k",
}
# verify export model path
if env_vars['export_path'] is None:
env_vars['export_path'] = env_vars['model_path'].rstrip("/") + "-" + env_vars['modifier']
if env_vars['schemes'] is not None:
env_vars['export_path'] += "-" + env_vars['schemes']
os.makedirs(env_vars['export_path'], exist_ok=True)
if env_vars["export_path"] is None:
env_vars["export_path"] = env_vars["model_path"].rstrip("/") + "-" + env_vars["modifier"]
if env_vars["schemes"] is not None:
env_vars["export_path"] += "-" + env_vars["schemes"]
os.makedirs(env_vars["export_path"], exist_ok=True)
return env_vars
@@ -74,19 +68,17 @@ def load_calibration_text_dataset(calib_prompt_path, tokenizer):
# Load dataset
for f in os.listdir(calib_prompt_path):
print(f)
if any(f.lower().endswith('.jsonl') for f in os.listdir(calib_prompt_path)):
ds = load_dataset('json', data_dir=calib_prompt_path, split='validation')
elif any(f.lower().endswith('.parquet') for f in os.listdir(calib_prompt_path)):
if any(f.lower().endswith(".jsonl") for f in os.listdir(calib_prompt_path)):
ds = load_dataset("json", data_dir=calib_prompt_path, split="validation")
elif any(f.lower().endswith(".parquet") for f in os.listdir(calib_prompt_path)):
ds = load_dataset("parquet", data_dir=calib_prompt_path, split="train[:512]")
else:
raise ValueError("Unsupported calibration file format: {}".format(
calib_prompt_path.split('.')[-1]))
raise ValueError("Unsupported calibration file format: {}".format(calib_prompt_path.split(".")[-1]))
# Preprocess dataset
def preprocess(example):
if tokenizer.chat_template is not None:
return {"text": tokenizer.apply_chat_template(
example["messages"], tokenize=False)}
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
else:
return {"text": example["messages"]}
@@ -118,8 +110,8 @@ def quantize_model(model, env_vars, dataset_dict=None):
# define a llmcompressor recipe
recipe = [
MODIFIER_DICT[env_vars['modifier']](
config_groups=SCHEMES_DICT[env_vars['schemes']],
MODIFIER_DICT[env_vars["modifier"]](
config_groups=SCHEMES_DICT[env_vars["schemes"]],
ignore=ignore,
),
]
@@ -138,18 +130,16 @@ def save_quantized_model(model, tokenizer, save_path, save_compressed=False):
tokenizer.save_pretrained(save_path)
if __name__ == '__main__':
if __name__ == "__main__":
# get environment variables
env_vars = load_environment_variables()
# support model type list
config = AutoConfig.from_pretrained(env_vars['model_path'], trust_remote_code=True)
config = AutoConfig.from_pretrained(env_vars["model_path"], trust_remote_code=True)
model_type = config.model_type
model = MODEL_DICT[model_type].from_pretrained(
env_vars['model_path'], torch_dtype="auto", trust_remote_code=True
)
tokenizer = TOKENIZER_DICT[model_type].from_pretrained(env_vars['model_path'], trust_remote_code=True)
model = MODEL_DICT[model_type].from_pretrained(env_vars["model_path"], torch_dtype="auto", trust_remote_code=True)
tokenizer = TOKENIZER_DICT[model_type].from_pretrained(env_vars["model_path"], trust_remote_code=True)
ds = load_calibration_text_dataset(env_vars["calib_prompt_path"], tokenizer)
@@ -157,4 +147,4 @@ if __name__ == '__main__':
quantize_model(model, env_vars, ds)
# save the quantized model
save_quantized_model(model, tokenizer, env_vars['export_path'], True)
save_quantized_model(model, tokenizer, env_vars["export_path"], True)


@@ -1,10 +1,9 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
from llmcompressor.utils import dispatch_for_generation
from transformers import AutoModelForCausalLM, AutoTokenizer
# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -80,4 +79,4 @@ print("==========================================\n\n")
# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-Dynamic-Per-Token"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)


@@ -48,7 +48,6 @@ plugins.md029.enabled = false # ol-prefix
line-length = 120
# Folder to be modified
exclude = [
"examples/**",
"tests/**",
"vllm_ascend/**",
]