[Lint]Style: Convert example to ruff format (#5863)
### What this PR does / why we need it?
This PR fixes linting issues in the `example/` to align with the
project's Ruff configuration.
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
Signed-off-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Co-authored-by: root <root@LAPTOP-VQKDDVMG.localdomain>
This commit is contained in:
@@ -4,7 +4,6 @@ default_install_hook_types:
|
|||||||
default_stages:
|
default_stages:
|
||||||
- pre-commit # Run locally
|
- pre-commit # Run locally
|
||||||
- manual # Run in CI
|
- manual # Run in CI
|
||||||
exclude: 'examples/.*' # Exclude examples from all hooks by default
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.14.0
|
rev: v0.14.0
|
||||||
|
|||||||
@@ -91,10 +91,8 @@ import heapq
|
|||||||
import ipaddress
|
import ipaddress
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
@@ -106,28 +104,28 @@ logger = init_logger(__name__)
|
|||||||
# Add uvloop for faster event loop if available
|
# Add uvloop for faster event loop if available
|
||||||
try:
|
try:
|
||||||
import uvloop
|
import uvloop
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ServerState:
|
class ServerState:
|
||||||
|
|
||||||
def __init__(self, host, port):
|
def __init__(self, host, port):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.url = f'http://{host}:{port}/v1'
|
self.url = f"http://{host}:{port}/v1"
|
||||||
try:
|
try:
|
||||||
ip = ipaddress.ip_address(self.host)
|
ip = ipaddress.ip_address(self.host)
|
||||||
if isinstance(ip, ipaddress.IPv6Address):
|
if isinstance(ip, ipaddress.IPv6Address):
|
||||||
self.url = f'http://[{host}]:{port}/v1'
|
self.url = f"http://[{host}]:{port}/v1"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self.client = httpx.AsyncClient(timeout=None,
|
self.client = httpx.AsyncClient(
|
||||||
base_url=self.url,
|
timeout=None,
|
||||||
limits=httpx.Limits(
|
base_url=self.url,
|
||||||
max_connections=100000,
|
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
|
||||||
max_keepalive_connections=100000))
|
)
|
||||||
self.active_tokens = 0
|
self.active_tokens = 0
|
||||||
self.active_kv_cache = 0 # Only for prefiller
|
self.active_kv_cache = 0 # Only for prefiller
|
||||||
self.active_requests = 0 # Number of active requests
|
self.active_requests = 0 # Number of active requests
|
||||||
@@ -136,14 +134,9 @@ class ServerState:
|
|||||||
|
|
||||||
|
|
||||||
class ProxyState:
|
class ProxyState:
|
||||||
|
|
||||||
def __init__(self, prefiller_instances, decoder_instances):
|
def __init__(self, prefiller_instances, decoder_instances):
|
||||||
self.prefillers: List[ServerState] = [
|
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
|
||||||
ServerState(h, p) for h, p in prefiller_instances
|
self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
|
||||||
]
|
|
||||||
self.decoders: List[ServerState] = [
|
|
||||||
ServerState(h, p) for h, p in decoder_instances
|
|
||||||
]
|
|
||||||
self.req_to_prefiller = {}
|
self.req_to_prefiller = {}
|
||||||
self.req_id_lock = asyncio.Lock()
|
self.req_id_lock = asyncio.Lock()
|
||||||
# Removed selection locks - no longer needed for synchronous methods
|
# Removed selection locks - no longer needed for synchronous methods
|
||||||
@@ -151,10 +144,8 @@ class ProxyState:
|
|||||||
# Initialize priority queues for efficient server selection
|
# Initialize priority queues for efficient server selection
|
||||||
# Each entry is (priority_score, server_index, server_reference)
|
# Each entry is (priority_score, server_index, server_reference)
|
||||||
# Lower priority score = higher priority (less loaded)
|
# Lower priority score = higher priority (less loaded)
|
||||||
self.prefiller_heap = [(0, i, server)
|
self.prefiller_heap = [(0, i, server) for i, server in enumerate(self.prefillers)]
|
||||||
for i, server in enumerate(self.prefillers)]
|
self.decoder_heap = [(0, i, server) for i, server in enumerate(self.decoders)]
|
||||||
self.decoder_heap = [(0, i, server)
|
|
||||||
for i, server in enumerate(self.decoders)]
|
|
||||||
heapq.heapify(self.prefiller_heap)
|
heapq.heapify(self.prefiller_heap)
|
||||||
heapq.heapify(self.decoder_heap)
|
heapq.heapify(self.decoder_heap)
|
||||||
self.req_id_future = {}
|
self.req_id_future = {}
|
||||||
@@ -166,23 +157,18 @@ class ProxyState:
|
|||||||
# Priority based on active_tokens and active_kv_cache
|
# Priority based on active_tokens and active_kv_cache
|
||||||
priority = server.active_tokens + server.active_kv_cache * 0.3
|
priority = server.active_tokens + server.active_kv_cache * 0.3
|
||||||
# Remove old entry and add new one
|
# Remove old entry and add new one
|
||||||
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
|
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
|
||||||
if i != server_idx]
|
heapq.heappush(self.prefiller_heap, (priority, server_idx, server)) # type: ignore
|
||||||
heapq.heappush(self.prefiller_heap,
|
|
||||||
(priority, server_idx, server)) # type: ignore
|
|
||||||
|
|
||||||
def _update_decoder_priority(self, server_idx: int):
|
def _update_decoder_priority(self, server_idx: int):
|
||||||
"""Update the priority of a decoder server in the heap."""
|
"""Update the priority of a decoder server in the heap."""
|
||||||
server = self.decoders[server_idx]
|
server = self.decoders[server_idx]
|
||||||
priority = server.active_tokens
|
priority = server.active_tokens
|
||||||
# Remove old entry and add new one
|
# Remove old entry and add new one
|
||||||
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
|
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
|
||||||
if i != server_idx]
|
heapq.heappush(self.decoder_heap, (priority, server_idx, server)) # type: ignore
|
||||||
heapq.heappush(self.decoder_heap,
|
|
||||||
(priority, server_idx, server)) # type: ignore
|
|
||||||
|
|
||||||
def abort_prefiller_request(self, server_idx: int,
|
def abort_prefiller_request(self, server_idx: int, request_id): # Changed to synchronous
|
||||||
request_id): # Changed to synchronous
|
|
||||||
"""
|
"""
|
||||||
Mark a request as aborted. This will helps to release kv cache in
|
Mark a request as aborted. This will helps to release kv cache in
|
||||||
prefiller node.
|
prefiller node.
|
||||||
@@ -190,8 +176,7 @@ class ProxyState:
|
|||||||
# No lock needed - atomic operation
|
# No lock needed - atomic operation
|
||||||
self.prefillers[server_idx].aborted_requests.add(request_id)
|
self.prefillers[server_idx].aborted_requests.add(request_id)
|
||||||
|
|
||||||
def aquire_aborted_prefiller_requests(
|
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
||||||
self, server_idx: int): # Changed to synchronous
|
|
||||||
"""
|
"""
|
||||||
Get the set of aborted requests and clear it.
|
Get the set of aborted requests and clear it.
|
||||||
This is used to release kv cache in prefiller node.
|
This is used to release kv cache in prefiller node.
|
||||||
@@ -272,37 +257,20 @@ def parse_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
parser.add_argument("--prefiller-hosts",
|
parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
|
||||||
type=str,
|
parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
|
||||||
nargs="+",
|
parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
|
||||||
default=["localhost"])
|
|
||||||
parser.add_argument("--prefiller-ports",
|
|
||||||
type=int,
|
|
||||||
nargs="+",
|
|
||||||
default=[8001])
|
|
||||||
parser.add_argument("--decoder-hosts",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
default=["localhost"])
|
|
||||||
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
|
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
|
||||||
parser.add_argument("--max-retries",
|
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="Maximum number of retries for HTTP requests")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--retry-delay",
|
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
|
||||||
type=float,
|
)
|
||||||
default=0.001,
|
|
||||||
help="Base delay (seconds) for exponential backoff retries")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
||||||
raise ValueError(
|
raise ValueError("Number of prefiller hosts must match number of prefiller ports")
|
||||||
"Number of prefiller hosts must match number of prefiller ports")
|
|
||||||
if len(args.decoder_hosts) != len(args.decoder_ports):
|
if len(args.decoder_hosts) != len(args.decoder_ports):
|
||||||
raise ValueError(
|
raise ValueError("Number of decoder hosts must match number of decoder ports")
|
||||||
"Number of decoder hosts must match number of decoder ports")
|
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
|
||||||
args.prefiller_instances = list(
|
|
||||||
zip(args.prefiller_hosts, args.prefiller_ports))
|
|
||||||
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
|
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -310,11 +278,8 @@ def parse_args():
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global proxy_state
|
global proxy_state
|
||||||
proxy_state = ProxyState(global_args.prefiller_instances,
|
proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
|
||||||
global_args.decoder_instances)
|
print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
|
||||||
print(
|
|
||||||
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
|
|
||||||
)
|
|
||||||
yield
|
yield
|
||||||
for p in proxy_state.prefillers:
|
for p in proxy_state.prefillers:
|
||||||
await p.client.aclose()
|
await p.client.aclose()
|
||||||
@@ -331,14 +296,12 @@ async def listen_for_disconnect(request: Request) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def with_cancellation(handler_func):
|
def with_cancellation(handler_func):
|
||||||
|
|
||||||
@functools.wraps(handler_func)
|
@functools.wraps(handler_func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
request = kwargs["request"]
|
request = kwargs["request"]
|
||||||
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
||||||
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
||||||
done, pending = await asyncio.wait([handler_task, cancellation_task],
|
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
|
||||||
return_when=asyncio.FIRST_COMPLETED)
|
|
||||||
for task in pending:
|
for task in pending:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
if handler_task in done:
|
if handler_task in done:
|
||||||
@@ -351,15 +314,16 @@ def with_cancellation(handler_func):
|
|||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
async def send_request_to_service(client: httpx.AsyncClient,
|
async def send_request_to_service(
|
||||||
prefiller_id: int,
|
client: httpx.AsyncClient,
|
||||||
endpoint: str,
|
prefiller_id: int,
|
||||||
req_data: dict,
|
endpoint: str,
|
||||||
request_id: str,
|
req_data: dict,
|
||||||
max_retries: int = 3,
|
request_id: str,
|
||||||
base_delay: float = 0.2):
|
max_retries: int = 3,
|
||||||
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
|
base_delay: float = 0.2,
|
||||||
prefiller_id)
|
):
|
||||||
|
proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
|
||||||
req_data = req_data.copy()
|
req_data = req_data.copy()
|
||||||
req_data["stream"] = False
|
req_data["stream"] = False
|
||||||
req_data["max_tokens"] = 1
|
req_data["max_tokens"] = 1
|
||||||
@@ -368,49 +332,38 @@ async def send_request_to_service(client: httpx.AsyncClient,
|
|||||||
req_data["max_completion_tokens"] = 1
|
req_data["max_completion_tokens"] = 1
|
||||||
if "stream_options" in req_data:
|
if "stream_options" in req_data:
|
||||||
del req_data["stream_options"]
|
del req_data["stream_options"]
|
||||||
headers = {
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
|
||||||
"X-Request-Id": request_id
|
|
||||||
}
|
|
||||||
last_exc = None
|
last_exc = None
|
||||||
for attempt in range(1, max_retries + 1):
|
for attempt in range(1, max_retries + 1):
|
||||||
try:
|
try:
|
||||||
response = await client.post(endpoint,
|
response = await client.post(endpoint, json=req_data, headers=headers)
|
||||||
json=req_data,
|
|
||||||
headers=headers)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
if request_id in proxy_state.req_id_future:
|
if request_id in proxy_state.req_id_future:
|
||||||
result_future = proxy_state.req_id_future[request_id]
|
result_future = proxy_state.req_id_future[request_id]
|
||||||
result_future.set_result(response.json()["kv_transfer_params"])
|
result_future.set_result(response.json()["kv_transfer_params"])
|
||||||
return
|
return
|
||||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||||
logger.warning(
|
logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
|
||||||
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
|
|
||||||
last_exc = e
|
last_exc = e
|
||||||
if attempt < max_retries:
|
if attempt < max_retries:
|
||||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"All {max_retries} attempts failed for {endpoint}.")
|
||||||
f"All {max_retries} attempts failed for {endpoint}.")
|
|
||||||
raise last_exc
|
raise last_exc
|
||||||
|
|
||||||
|
|
||||||
async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
async def stream_service_response_with_retry(
|
||||||
endpoint: str,
|
client: httpx.AsyncClient,
|
||||||
req_data: dict,
|
endpoint: str,
|
||||||
request_id: str,
|
req_data: dict,
|
||||||
max_retries: int = 3,
|
request_id: str,
|
||||||
base_delay: float = 0.2):
|
max_retries: int = 3,
|
||||||
headers = {
|
base_delay: float = 0.2,
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
):
|
||||||
"X-Request-Id": request_id
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
|
||||||
}
|
|
||||||
for attempt in range(1, max_retries + 1):
|
for attempt in range(1, max_retries + 1):
|
||||||
try:
|
try:
|
||||||
async with client.stream("POST",
|
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
|
||||||
endpoint,
|
|
||||||
json=req_data,
|
|
||||||
headers=headers) as response:
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
first_chunk_sent = False
|
first_chunk_sent = False
|
||||||
async for chunk in response.aiter_bytes():
|
async for chunk in response.aiter_bytes():
|
||||||
@@ -419,32 +372,22 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
|||||||
return # Success, exit after streaming
|
return # Success, exit after streaming
|
||||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||||
if attempt < max_retries:
|
if attempt < max_retries:
|
||||||
logger.warning(
|
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
|
||||||
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||||
)
|
|
||||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
|
||||||
f"All {max_retries} attempts failed for streaming {endpoint}."
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If any chunk has been sent, do not retry, just log and drop
|
# If any chunk has been sent, do not retry, just log and drop
|
||||||
if 'first_chunk_sent' in locals() and first_chunk_sent:
|
if "first_chunk_sent" in locals() and first_chunk_sent:
|
||||||
logger.error(
|
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
|
||||||
f"Streaming to client interrupted after response started: {str(e)}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
if attempt < max_retries:
|
if attempt < max_retries:
|
||||||
logger.warning(
|
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
|
||||||
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||||
)
|
|
||||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
|
||||||
f"All {max_retries} attempts failed for streaming {endpoint}."
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
@@ -469,15 +412,11 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
request_length = len(req_body)
|
request_length = len(req_body)
|
||||||
request_id = await proxy_state.next_req_id()
|
request_id = await proxy_state.next_req_id()
|
||||||
request_id_api = get_api_request_id(api, request_id)
|
request_id_api = get_api_request_id(api, request_id)
|
||||||
proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
|
proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
|
||||||
api)
|
req_data["kv_transfer_params"] = {
|
||||||
req_data['kv_transfer_params'] = {
|
"do_remote_decode": False,
|
||||||
"do_remote_decode":
|
"do_remote_prefill": True,
|
||||||
False,
|
"metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver",
|
||||||
"do_remote_prefill":
|
|
||||||
True,
|
|
||||||
"metaserver":
|
|
||||||
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
|
|
||||||
}
|
}
|
||||||
# Select decoder
|
# Select decoder
|
||||||
decoder_score = proxy_state.calculate_decode_scores(request_length)
|
decoder_score = proxy_state.calculate_decode_scores(request_length)
|
||||||
@@ -494,28 +433,30 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
# Only one await per chunk, minimal logic in loop
|
# Only one await per chunk, minimal logic in loop
|
||||||
try:
|
try:
|
||||||
async for chunk in stream_service_response_with_retry(
|
async for chunk in stream_service_response_with_retry(
|
||||||
decoder.client,
|
decoder.client,
|
||||||
api,
|
api,
|
||||||
req_data,
|
req_data,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
max_retries=global_args.max_retries,
|
max_retries=global_args.max_retries,
|
||||||
base_delay=global_args.retry_delay):
|
base_delay=global_args.retry_delay,
|
||||||
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
|
f"Error during streaming from decoder {decoder.url}: {str(e)} "
|
||||||
|
f"the aborted request {request_id} will be routing to the target "
|
||||||
|
"prefiller when new request is ready to dispatch to it"
|
||||||
)
|
)
|
||||||
|
|
||||||
# After streaming done, release tokens
|
# After streaming done, release tokens
|
||||||
proxy_state.release_decoder(decoder_idx, decoder_score)
|
proxy_state.release_decoder(decoder_idx, decoder_score)
|
||||||
|
|
||||||
return StreamingResponse(generate_stream(),
|
return StreamingResponse(generate_stream(), media_type="application/json")
|
||||||
media_type="application/json")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
print("Error occurred in disagg prefill proxy server"
|
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
|
||||||
f" - {api} endpoint")
|
|
||||||
print(e)
|
print(e)
|
||||||
print("".join(traceback.format_exception(*exc_info)))
|
print("".join(traceback.format_exception(*exc_info)))
|
||||||
raise
|
raise
|
||||||
@@ -538,7 +479,7 @@ async def healthcheck():
|
|||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"prefill_instances": len(proxy_state.prefillers),
|
"prefill_instances": len(proxy_state.prefillers),
|
||||||
"decode_instances": len(proxy_state.decoders)
|
"decode_instances": len(proxy_state.decoders),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -553,25 +494,24 @@ async def metaserver(request: Request):
|
|||||||
request_id = get_origin_request_id(api, request_id)
|
request_id = get_origin_request_id(api, request_id)
|
||||||
req_data["kv_transfer_params"] = kv_transfer_params
|
req_data["kv_transfer_params"] = kv_transfer_params
|
||||||
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
|
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
|
||||||
logger.debug(
|
logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
|
||||||
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Select prefiller
|
# Select prefiller
|
||||||
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
|
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
|
||||||
prefiller = proxy_state.prefillers[prefiller_idx]
|
prefiller = proxy_state.prefillers[prefiller_idx]
|
||||||
logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
|
logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
|
||||||
# Send request to prefiller
|
# Send request to prefiller
|
||||||
response = await send_request_to_service(
|
await send_request_to_service(
|
||||||
prefiller.client,
|
prefiller.client,
|
||||||
prefiller_idx,
|
prefiller_idx,
|
||||||
api,
|
api,
|
||||||
req_data,
|
req_data,
|
||||||
request_id,
|
request_id,
|
||||||
max_retries=global_args.max_retries,
|
max_retries=global_args.max_retries,
|
||||||
base_delay=global_args.retry_delay)
|
base_delay=global_args.retry_delay,
|
||||||
|
)
|
||||||
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
|
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
|
||||||
proxy_state.release_prefiller_kv(prefiller_idx,prefiller_score)
|
proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Post metaserver failed with: {str(e)}")
|
logger.error(f"Post metaserver failed with: {str(e)}")
|
||||||
@@ -579,8 +519,9 @@ async def metaserver(request: Request):
|
|||||||
proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
|
proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
global global_args
|
global global_args
|
||||||
global_args = parse_args()
|
global_args = parse_args()
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List, Tuple, Dict
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
@@ -150,22 +150,21 @@ class InstanceType:
|
|||||||
|
|
||||||
|
|
||||||
class ServerState:
|
class ServerState:
|
||||||
|
|
||||||
def __init__(self, host, port):
|
def __init__(self, host, port):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.url = f'http://{host}:{port}/v1'
|
self.url = f"http://{host}:{port}/v1"
|
||||||
try:
|
try:
|
||||||
ip = ipaddress.ip_address(self.host)
|
ip = ipaddress.ip_address(self.host)
|
||||||
if isinstance(ip, ipaddress.IPv6Address):
|
if isinstance(ip, ipaddress.IPv6Address):
|
||||||
self.url = f'http://[{host}]:{port}/v1'
|
self.url = f"http://[{host}]:{port}/v1"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self.client = httpx.AsyncClient(timeout=None,
|
self.client = httpx.AsyncClient(
|
||||||
base_url=self.url,
|
timeout=None,
|
||||||
limits=httpx.Limits(
|
base_url=self.url,
|
||||||
max_connections=100000,
|
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
|
||||||
max_keepalive_connections=100000))
|
)
|
||||||
self.active_tokens = 0
|
self.active_tokens = 0
|
||||||
self.active_kv_cache = 0 # Only for prefiller
|
self.active_kv_cache = 0 # Only for prefiller
|
||||||
self.active_requests = 0 # Number of active requests
|
self.active_requests = 0 # Number of active requests
|
||||||
@@ -186,16 +185,11 @@ class ServerState:
|
|||||||
|
|
||||||
|
|
||||||
class ProxyState:
|
class ProxyState:
|
||||||
|
|
||||||
def __init__(self, prefiller_instances, decoder_instances):
|
def __init__(self, prefiller_instances, decoder_instances):
|
||||||
self.node_listener = NodeListener(self)
|
self.node_listener = NodeListener(self)
|
||||||
|
|
||||||
self.prefillers: List[ServerState] = [
|
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
|
||||||
ServerState(h, p) for h, p in prefiller_instances
|
self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
|
||||||
]
|
|
||||||
self.decoders: List[ServerState] = [
|
|
||||||
ServerState(h, p) for h, p in decoder_instances
|
|
||||||
]
|
|
||||||
self.req_to_prefiller = {}
|
self.req_to_prefiller = {}
|
||||||
self.req_id_lock = asyncio.Lock()
|
self.req_id_lock = asyncio.Lock()
|
||||||
# Removed selection locks - no longer needed for synchronous methods
|
# Removed selection locks - no longer needed for synchronous methods
|
||||||
@@ -203,10 +197,8 @@ class ProxyState:
|
|||||||
# Initialize priority queues for efficient server selection
|
# Initialize priority queues for efficient server selection
|
||||||
# Each entry is (priority_score, server_index, server_reference)
|
# Each entry is (priority_score, server_index, server_reference)
|
||||||
# Lower priority score = higher priority (less loaded)
|
# Lower priority score = higher priority (less loaded)
|
||||||
self.prefiller_heap = [(0, i, server)
|
self.prefiller_heap = [(0, i, server) for i, server in enumerate(self.prefillers)]
|
||||||
for i, server in enumerate(self.prefillers)]
|
self.decoder_heap = [(0, i, server) for i, server in enumerate(self.decoders)]
|
||||||
self.decoder_heap = [(0, i, server)
|
|
||||||
for i, server in enumerate(self.decoders)]
|
|
||||||
heapq.heapify(self.prefiller_heap)
|
heapq.heapify(self.prefiller_heap)
|
||||||
heapq.heapify(self.decoder_heap)
|
heapq.heapify(self.decoder_heap)
|
||||||
|
|
||||||
@@ -216,23 +208,18 @@ class ProxyState:
|
|||||||
# Priority based on active_tokens and active_kv_cache
|
# Priority based on active_tokens and active_kv_cache
|
||||||
priority = server.active_tokens + server.active_kv_cache * 0.3
|
priority = server.active_tokens + server.active_kv_cache * 0.3
|
||||||
# Remove old entry and add new one
|
# Remove old entry and add new one
|
||||||
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
|
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
|
||||||
if i != server_idx]
|
heapq.heappush(self.prefiller_heap, (priority, server_idx, server)) # type: ignore
|
||||||
heapq.heappush(self.prefiller_heap,
|
|
||||||
(priority, server_idx, server)) # type: ignore
|
|
||||||
|
|
||||||
def _update_decoder_priority(self, server_idx: int):
|
def _update_decoder_priority(self, server_idx: int):
|
||||||
"""Update the priority of a decoder server in the heap."""
|
"""Update the priority of a decoder server in the heap."""
|
||||||
server = self.decoders[server_idx]
|
server = self.decoders[server_idx]
|
||||||
priority = server.active_tokens
|
priority = server.active_tokens
|
||||||
# Remove old entry and add new one
|
# Remove old entry and add new one
|
||||||
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
|
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
|
||||||
if i != server_idx]
|
heapq.heappush(self.decoder_heap, (priority, server_idx, server)) # type: ignore
|
||||||
heapq.heappush(self.decoder_heap,
|
|
||||||
(priority, server_idx, server)) # type: ignore
|
|
||||||
|
|
||||||
def abort_prefiller_request(self, server_idx: int,
|
def abort_prefiller_request(self, server_idx: int, request_id): # Changed to synchronous
|
||||||
request_id): # Changed to synchronous
|
|
||||||
"""
|
"""
|
||||||
Mark a request as aborted. This will helps to release kv cache in
|
Mark a request as aborted. This will helps to release kv cache in
|
||||||
prefiller node.
|
prefiller node.
|
||||||
@@ -240,8 +227,7 @@ class ProxyState:
|
|||||||
# No lock needed - atomic operation
|
# No lock needed - atomic operation
|
||||||
self.prefillers[server_idx].aborted_requests.add(request_id)
|
self.prefillers[server_idx].aborted_requests.add(request_id)
|
||||||
|
|
||||||
def aquire_aborted_prefiller_requests(
|
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
||||||
self, server_idx: int): # Changed to synchronous
|
|
||||||
"""
|
"""
|
||||||
Get the set of aborted requests and clear it.
|
Get the set of aborted requests and clear it.
|
||||||
This is used to release kv cache in prefiller node.
|
This is used to release kv cache in prefiller node.
|
||||||
@@ -314,9 +300,7 @@ class ProxyState:
|
|||||||
def calculate_decode_scores(self, request_length: int) -> float:
|
def calculate_decode_scores(self, request_length: int) -> float:
|
||||||
return request_length
|
return request_length
|
||||||
|
|
||||||
async def add_instances(
|
async def add_instances(self, instance_type: str, instances: list[ServerState]) -> tuple[list[str], list[str]]:
|
||||||
self, instance_type: str, instances: List[ServerState]
|
|
||||||
) -> Tuple[List[str], List[str]]:
|
|
||||||
added_nodes, waiting_nodes = [], []
|
added_nodes, waiting_nodes = [], []
|
||||||
for server in instances:
|
for server in instances:
|
||||||
is_valid = await self.node_listener.check_instance_status(server.client)
|
is_valid = await self.node_listener.check_instance_status(server.client)
|
||||||
@@ -332,7 +316,7 @@ class ProxyState:
|
|||||||
waiting_nodes.append(node)
|
waiting_nodes.append(node)
|
||||||
return added_nodes, waiting_nodes
|
return added_nodes, waiting_nodes
|
||||||
|
|
||||||
def add_prefillers(self, instances: List[ServerState]) -> None:
|
def add_prefillers(self, instances: list[ServerState]) -> None:
|
||||||
num_prefillers = len(self.prefillers)
|
num_prefillers = len(self.prefillers)
|
||||||
for idx, server in enumerate(instances):
|
for idx, server in enumerate(instances):
|
||||||
if server not in self.prefillers:
|
if server not in self.prefillers:
|
||||||
@@ -341,7 +325,7 @@ class ProxyState:
|
|||||||
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
|
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
|
||||||
self.print_status(f"Add prefiller instances: {instances}.")
|
self.print_status(f"Add prefiller instances: {instances}.")
|
||||||
|
|
||||||
def add_decoders(self, instances: List[ServerState]) -> None:
|
def add_decoders(self, instances: list[ServerState]) -> None:
|
||||||
num_decoders = len(self.decoders)
|
num_decoders = len(self.decoders)
|
||||||
for idx, server in enumerate(instances):
|
for idx, server in enumerate(instances):
|
||||||
if server not in self.decoders:
|
if server not in self.decoders:
|
||||||
@@ -350,7 +334,7 @@ class ProxyState:
|
|||||||
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
|
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
|
||||||
self.print_status(f"Add decoder instances: {instances}.")
|
self.print_status(f"Add decoder instances: {instances}.")
|
||||||
|
|
||||||
def remove_prefillers(self, instances: List[ServerState]) -> None:
|
def remove_prefillers(self, instances: list[ServerState]) -> None:
|
||||||
instances_to_remove = set(instances)
|
instances_to_remove = set(instances)
|
||||||
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
|
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
|
||||||
prefiller_heap_copy = self.prefiller_heap.copy()
|
prefiller_heap_copy = self.prefiller_heap.copy()
|
||||||
@@ -367,7 +351,7 @@ class ProxyState:
|
|||||||
heapq.heapify(self.prefiller_heap)
|
heapq.heapify(self.prefiller_heap)
|
||||||
self.print_status(f"Remove prefiller instances: {instances}.")
|
self.print_status(f"Remove prefiller instances: {instances}.")
|
||||||
|
|
||||||
def remove_decoders(self, instances: List[ServerState]) -> None:
|
def remove_decoders(self, instances: list[ServerState]) -> None:
|
||||||
instances_to_remove = set(instances)
|
instances_to_remove = set(instances)
|
||||||
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
|
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
|
||||||
decoder_heap_copy = self.decoder_heap.copy()
|
decoder_heap_copy = self.decoder_heap.copy()
|
||||||
@@ -387,7 +371,7 @@ class ProxyState:
|
|||||||
def print_status(self, msg: str) -> None:
|
def print_status(self, msg: str) -> None:
|
||||||
status = {
|
status = {
|
||||||
"prefill_instances": [str(server) for server in self.prefillers],
|
"prefill_instances": [str(server) for server in self.prefillers],
|
||||||
"decode_instances": [str(server) for server in self.decoders]
|
"decode_instances": [str(server) for server in self.decoders],
|
||||||
}
|
}
|
||||||
print(f"{msg} Status: {status}")
|
print(f"{msg} Status: {status}")
|
||||||
|
|
||||||
@@ -398,7 +382,7 @@ proxy_state = None
|
|||||||
class NodeListener:
|
class NodeListener:
|
||||||
def __init__(self, proxy):
|
def __init__(self, proxy):
|
||||||
self.proxy_state = proxy
|
self.proxy_state = proxy
|
||||||
self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
|
self.waiting_nodes: dict[str, tuple[str, Any, int]] = {}
|
||||||
self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
|
self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
|
||||||
self.listening_thread.start()
|
self.listening_thread.start()
|
||||||
|
|
||||||
@@ -424,9 +408,7 @@ class NodeListener:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
async def check_instance_status(client: httpx.AsyncClient) -> bool:
|
async def check_instance_status(client: httpx.AsyncClient) -> bool:
|
||||||
endpoint = "/models"
|
endpoint = "/models"
|
||||||
headers = {
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
|
||||||
}
|
|
||||||
try:
|
try:
|
||||||
response = await client.get(endpoint, headers=headers)
|
response = await client.get(endpoint, headers=headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@@ -439,46 +421,29 @@ def parse_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
parser.add_argument("--prefiller-hosts",
|
parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
|
||||||
type=str,
|
parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
|
||||||
nargs="+",
|
parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
|
||||||
default=["localhost"])
|
|
||||||
parser.add_argument("--prefiller-ports",
|
|
||||||
type=int,
|
|
||||||
nargs="+",
|
|
||||||
default=[8001])
|
|
||||||
parser.add_argument("--decoder-hosts",
|
|
||||||
type=str,
|
|
||||||
nargs="+",
|
|
||||||
default=["localhost"])
|
|
||||||
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
|
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
|
||||||
parser.add_argument("--max-retries",
|
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="Maximum number of retries for HTTP requests")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--retry-delay",
|
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
|
||||||
type=float,
|
)
|
||||||
default=0.001,
|
parser.add_argument(
|
||||||
help="Base delay (seconds) for exponential backoff retries")
|
"--max-waiting-retries", type=int, default=3, help="Maximum number of retries for waiting nodes to be started"
|
||||||
parser.add_argument("--max-waiting-retries",
|
)
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="Maximum number of retries for waiting nodes to be started")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--waiting-retry-interval",
|
"--waiting-retry-interval",
|
||||||
type=float,
|
type=float,
|
||||||
default=10,
|
default=10,
|
||||||
help="Check interval (seconds) for waiting nodes to be started")
|
help="Check interval (seconds) for waiting nodes to be started",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
||||||
raise ValueError(
|
raise ValueError("Number of prefiller hosts must match number of prefiller ports")
|
||||||
"Number of prefiller hosts must match number of prefiller ports")
|
|
||||||
if len(args.decoder_hosts) != len(args.decoder_ports):
|
if len(args.decoder_hosts) != len(args.decoder_ports):
|
||||||
raise ValueError(
|
raise ValueError("Number of decoder hosts must match number of decoder ports")
|
||||||
"Number of decoder hosts must match number of decoder ports")
|
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
|
||||||
args.prefiller_instances = list(
|
|
||||||
zip(args.prefiller_hosts, args.prefiller_ports))
|
|
||||||
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
|
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -486,11 +451,8 @@ def parse_args():
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global proxy_state
|
global proxy_state
|
||||||
proxy_state = ProxyState(global_args.prefiller_instances,
|
proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
|
||||||
global_args.decoder_instances)
|
print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
|
||||||
print(
|
|
||||||
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
|
|
||||||
)
|
|
||||||
yield
|
yield
|
||||||
for p in proxy_state.prefillers:
|
for p in proxy_state.prefillers:
|
||||||
await p.client.aclose()
|
await p.client.aclose()
|
||||||
@@ -507,14 +469,12 @@ async def listen_for_disconnect(request: Request) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def with_cancellation(handler_func):
|
def with_cancellation(handler_func):
|
||||||
|
|
||||||
@functools.wraps(handler_func)
|
@functools.wraps(handler_func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
request = kwargs["request"]
|
request = kwargs["request"]
|
||||||
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
||||||
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
||||||
done, pending = await asyncio.wait([handler_task, cancellation_task],
|
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
|
||||||
return_when=asyncio.FIRST_COMPLETED)
|
|
||||||
for task in pending:
|
for task in pending:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
if handler_task in done:
|
if handler_task in done:
|
||||||
@@ -527,17 +487,18 @@ def with_cancellation(handler_func):
|
|||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
async def send_request_to_service(client: httpx.AsyncClient,
|
async def send_request_to_service(
|
||||||
prefiller_id: int,
|
client: httpx.AsyncClient,
|
||||||
endpoint: str,
|
prefiller_id: int,
|
||||||
req_data: dict,
|
endpoint: str,
|
||||||
request_id: str,
|
req_data: dict,
|
||||||
max_retries: int = 3,
|
request_id: str,
|
||||||
base_delay: float = 0.2):
|
max_retries: int = 3,
|
||||||
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
|
base_delay: float = 0.2,
|
||||||
prefiller_id)
|
):
|
||||||
|
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
|
||||||
req_data = req_data.copy()
|
req_data = req_data.copy()
|
||||||
req_data['kv_transfer_params'] = {
|
req_data["kv_transfer_params"] = {
|
||||||
"do_remote_decode": True,
|
"do_remote_decode": True,
|
||||||
"do_remote_prefill": False,
|
"do_remote_prefill": False,
|
||||||
"remote_engine_id": None,
|
"remote_engine_id": None,
|
||||||
@@ -553,46 +514,35 @@ async def send_request_to_service(client: httpx.AsyncClient,
|
|||||||
req_data["max_completion_tokens"] = 1
|
req_data["max_completion_tokens"] = 1
|
||||||
if "stream_options" in req_data:
|
if "stream_options" in req_data:
|
||||||
del req_data["stream_options"]
|
del req_data["stream_options"]
|
||||||
headers = {
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
|
||||||
"X-Request-Id": request_id
|
|
||||||
}
|
|
||||||
last_exc = None
|
last_exc = None
|
||||||
for attempt in range(1, max_retries + 1):
|
for attempt in range(1, max_retries + 1):
|
||||||
try:
|
try:
|
||||||
response = await client.post(endpoint,
|
response = await client.post(endpoint, json=req_data, headers=headers)
|
||||||
json=req_data,
|
|
||||||
headers=headers)
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response
|
return response
|
||||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||||
logger.warning(
|
logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
|
||||||
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
|
|
||||||
last_exc = e
|
last_exc = e
|
||||||
if attempt < max_retries:
|
if attempt < max_retries:
|
||||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"All {max_retries} attempts failed for {endpoint}.")
|
||||||
f"All {max_retries} attempts failed for {endpoint}.")
|
|
||||||
raise last_exc
|
raise last_exc
|
||||||
|
|
||||||
|
|
||||||
async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
async def stream_service_response_with_retry(
|
||||||
endpoint: str,
|
client: httpx.AsyncClient,
|
||||||
req_data: dict,
|
endpoint: str,
|
||||||
request_id: str,
|
req_data: dict,
|
||||||
max_retries: int = 3,
|
request_id: str,
|
||||||
base_delay: float = 0.2):
|
max_retries: int = 3,
|
||||||
headers = {
|
base_delay: float = 0.2,
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
):
|
||||||
"X-Request-Id": request_id
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
|
||||||
}
|
|
||||||
for attempt in range(1, max_retries + 1):
|
for attempt in range(1, max_retries + 1):
|
||||||
try:
|
try:
|
||||||
async with client.stream("POST",
|
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
|
||||||
endpoint,
|
|
||||||
json=req_data,
|
|
||||||
headers=headers) as response:
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
first_chunk_sent = False
|
first_chunk_sent = False
|
||||||
async for chunk in response.aiter_bytes():
|
async for chunk in response.aiter_bytes():
|
||||||
@@ -601,41 +551,28 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
|||||||
return # Success, exit after streaming
|
return # Success, exit after streaming
|
||||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||||
if attempt < max_retries:
|
if attempt < max_retries:
|
||||||
logger.warning(
|
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
|
||||||
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||||
)
|
|
||||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
|
||||||
f"All {max_retries} attempts failed for streaming {endpoint}."
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If any chunk has been sent, do not retry, just log and drop
|
# If any chunk has been sent, do not retry, just log and drop
|
||||||
if 'first_chunk_sent' in locals() and first_chunk_sent:
|
if "first_chunk_sent" in locals() and first_chunk_sent:
|
||||||
logger.error(
|
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
|
||||||
f"Streaming to client interrupted after response started: {str(e)}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
if attempt < max_retries:
|
if attempt < max_retries:
|
||||||
logger.warning(
|
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
|
||||||
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||||
)
|
|
||||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
|
||||||
f"All {max_retries} attempts failed for streaming {endpoint}."
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
async def _handle_select_instance(api: str, req_data: Any,
|
async def _handle_select_instance(api: str, req_data: Any, request_length: int):
|
||||||
request_length: int):
|
|
||||||
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
|
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
|
||||||
logger.debug(
|
logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
|
||||||
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
|
|
||||||
)
|
|
||||||
request_id = await proxy_state.next_req_id()
|
request_id = await proxy_state.next_req_id()
|
||||||
# Select prefiller
|
# Select prefiller
|
||||||
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
|
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
|
||||||
@@ -648,10 +585,11 @@ async def _handle_select_instance(api: str, req_data: Any,
|
|||||||
req_data,
|
req_data,
|
||||||
request_id,
|
request_id,
|
||||||
max_retries=global_args.max_retries,
|
max_retries=global_args.max_retries,
|
||||||
base_delay=global_args.retry_delay)
|
base_delay=global_args.retry_delay,
|
||||||
|
)
|
||||||
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
|
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
kv_transfer_params = response_json.get('kv_transfer_params', {})
|
kv_transfer_params = response_json.get("kv_transfer_params", {})
|
||||||
if kv_transfer_params:
|
if kv_transfer_params:
|
||||||
req_data["kv_transfer_params"] = kv_transfer_params
|
req_data["kv_transfer_params"] = kv_transfer_params
|
||||||
# Select decoder
|
# Select decoder
|
||||||
@@ -661,13 +599,15 @@ async def _handle_select_instance(api: str, req_data: Any,
|
|||||||
decoder_idx = proxy_state.select_decoder(decoder_score)
|
decoder_idx = proxy_state.select_decoder(decoder_score)
|
||||||
decoder = proxy_state.decoders[decoder_idx]
|
decoder = proxy_state.decoders[decoder_idx]
|
||||||
logger.debug("Using %s %s", prefiller.url, decoder.url)
|
logger.debug("Using %s %s", prefiller.url, decoder.url)
|
||||||
return InstanceInfo(request_id=request_id,
|
return InstanceInfo(
|
||||||
prefiller_idx=prefiller_idx,
|
request_id=request_id,
|
||||||
prefiller_score=prefiller_score,
|
prefiller_idx=prefiller_idx,
|
||||||
prefiller=prefiller,
|
prefiller_score=prefiller_score,
|
||||||
decoder=decoder,
|
prefiller=prefiller,
|
||||||
decoder_idx=decoder_idx,
|
decoder=decoder,
|
||||||
decoder_score=decoder_score)
|
decoder_idx=decoder_idx,
|
||||||
|
decoder_score=decoder_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -686,8 +626,7 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
req_data = await request.json()
|
req_data = await request.json()
|
||||||
req_body = await request.body()
|
req_body = await request.body()
|
||||||
request_length = len(req_body)
|
request_length = len(req_body)
|
||||||
instance_info = await _handle_select_instance(api, req_data,
|
instance_info = await _handle_select_instance(api, req_data, request_length)
|
||||||
request_length)
|
|
||||||
stream_flag = bool(req_data.get("stream", False))
|
stream_flag = bool(req_data.get("stream", False))
|
||||||
chat_flag = "messages" in req_data
|
chat_flag = "messages" in req_data
|
||||||
|
|
||||||
@@ -713,34 +652,31 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
while retry:
|
while retry:
|
||||||
retry = False
|
retry = False
|
||||||
async for chunk in stream_service_response_with_retry(
|
async for chunk in stream_service_response_with_retry(
|
||||||
instance_info.decoder.client,
|
instance_info.decoder.client,
|
||||||
api,
|
api,
|
||||||
req_data,
|
req_data,
|
||||||
request_id=instance_info.request_id,
|
request_id=instance_info.request_id,
|
||||||
max_retries=global_args.max_retries,
|
max_retries=global_args.max_retries,
|
||||||
base_delay=global_args.retry_delay):
|
base_delay=global_args.retry_delay,
|
||||||
|
):
|
||||||
if not released_kv and chunk:
|
if not released_kv and chunk:
|
||||||
proxy_state.release_prefiller_kv(
|
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
|
||||||
instance_info.prefiller_idx,
|
|
||||||
instance_info.prefiller_score)
|
|
||||||
released_kv = True
|
released_kv = True
|
||||||
try:
|
try:
|
||||||
chunk_str = chunk.decode("utf-8").strip()
|
chunk_str = chunk.decode("utf-8").strip()
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
logger.debug(
|
logger.debug(f"Skipping chunk: {chunk}")
|
||||||
f"Skipping chunk: {chunk}")
|
|
||||||
yield chunk
|
yield chunk
|
||||||
continue
|
continue
|
||||||
if not chunk_str:
|
if not chunk_str:
|
||||||
continue
|
continue
|
||||||
if chunk_str.startswith("data: "):
|
if chunk_str.startswith("data: "):
|
||||||
chunk_str = chunk_str[len("data: "):]
|
chunk_str = chunk_str[len("data: ") :]
|
||||||
try:
|
try:
|
||||||
chunk_json = json.loads(chunk_str)
|
chunk_json = json.loads(chunk_str)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# if chunk is [done], skip it.
|
# if chunk is [done], skip it.
|
||||||
logger.debug(
|
logger.debug(f"Skipping chunk: {chunk_str}")
|
||||||
f"Skipping chunk: {chunk_str}")
|
|
||||||
yield chunk
|
yield chunk
|
||||||
continue
|
continue
|
||||||
choices = chunk_json.get("choices", [])
|
choices = chunk_json.get("choices", [])
|
||||||
@@ -751,63 +687,52 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
choice = choices[0]
|
choice = choices[0]
|
||||||
delta = choice.get("delta") or {}
|
delta = choice.get("delta") or {}
|
||||||
message = choice.get("message") or {}
|
message = choice.get("message") or {}
|
||||||
content = (
|
content = delta.get("content") or message.get("content") or choice.get("text") or ""
|
||||||
delta.get("content")
|
|
||||||
or message.get("content")
|
|
||||||
or choice.get("text")
|
|
||||||
or ""
|
|
||||||
)
|
|
||||||
generated_token += content
|
generated_token += content
|
||||||
|
|
||||||
stop_reason = choice.get(
|
stop_reason = choice.get("stop_reason")
|
||||||
"stop_reason")
|
|
||||||
usage = chunk_json.get("usage", {})
|
usage = chunk_json.get("usage", {})
|
||||||
completion_tokens = (completion_tokens + 1) if stream_flag else \
|
completion_tokens = (
|
||||||
(completion_tokens + usage.get("completion_tokens"))
|
(completion_tokens + 1)
|
||||||
|
if stream_flag
|
||||||
|
else (completion_tokens + usage.get("completion_tokens"))
|
||||||
|
)
|
||||||
if stop_reason == "recomputed":
|
if stop_reason == "recomputed":
|
||||||
retry = True
|
retry = True
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
if chat_flag:
|
if chat_flag:
|
||||||
messages[0][
|
messages[0]["content"] = origin_prompt + generated_token
|
||||||
"content"] = origin_prompt + generated_token
|
|
||||||
else:
|
else:
|
||||||
req_data[
|
req_data["prompt"] = origin_prompt + generated_token
|
||||||
"prompt"] = origin_prompt + generated_token
|
req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
|
||||||
req_data[
|
tmp_request_length = len(json.dumps(req_data).encode("utf-8"))
|
||||||
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
|
instance_info = await _handle_select_instance(api, req_data, tmp_request_length)
|
||||||
tmp_request_length = len(
|
|
||||||
json.dumps(req_data).encode("utf-8"))
|
|
||||||
instance_info = await _handle_select_instance(
|
|
||||||
api, req_data, tmp_request_length)
|
|
||||||
break
|
break
|
||||||
if retry_count > 0 and not stream_flag:
|
if retry_count > 0 and not stream_flag:
|
||||||
if chat_flag:
|
if chat_flag:
|
||||||
choice["message"][
|
choice["message"]["content"] = generated_token
|
||||||
"content"] = generated_token
|
|
||||||
else:
|
else:
|
||||||
choice["text"] = generated_token
|
choice["text"] = generated_token
|
||||||
chunk = json.dumps(chunk_json).encode("utf-8")
|
chunk = json.dumps(chunk_json).encode("utf-8")
|
||||||
yield chunk
|
yield chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
|
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} "
|
||||||
|
f"the aborted request {instance_info.request_id} will be routing to the target "
|
||||||
|
"prefiller when new request is ready to dispatch to it"
|
||||||
)
|
)
|
||||||
proxy_state.abort_prefiller_request(
|
proxy_state.abort_prefiller_request(instance_info.prefiller_idx, instance_info.request_id)
|
||||||
instance_info.prefiller_idx, instance_info.request_id)
|
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
|
||||||
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
|
|
||||||
instance_info.prefiller_score)
|
|
||||||
|
|
||||||
# After streaming done, release tokens
|
# After streaming done, release tokens
|
||||||
proxy_state.release_decoder(instance_info.decoder_idx,
|
proxy_state.release_decoder(instance_info.decoder_idx, instance_info.decoder_score)
|
||||||
instance_info.decoder_score)
|
|
||||||
|
|
||||||
return StreamingResponse(generate_stream(),
|
return StreamingResponse(generate_stream(), media_type="application/json")
|
||||||
media_type="application/json")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
print("Error occurred in disagg prefill proxy server"
|
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
|
||||||
f" - {api} endpoint")
|
|
||||||
print(e)
|
print(e)
|
||||||
print("".join(traceback.format_exception(*exc_info)))
|
print("".join(traceback.format_exception(*exc_info)))
|
||||||
raise
|
raise
|
||||||
@@ -821,20 +746,21 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
|||||||
if isinstance(instances, str):
|
if isinstance(instances, str):
|
||||||
instances = [instances]
|
instances = [instances]
|
||||||
instances = trans_instances(instances)
|
instances = trans_instances(instances)
|
||||||
all_msg = f"{adjust_mode} {instance_type} instances: " \
|
all_msg = f"{adjust_mode} {instance_type} instances: {[str(server) for server in instances]}."
|
||||||
f"{[str(server) for server in instances]}."
|
|
||||||
|
|
||||||
if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
|
if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
|
||||||
return {"error": f"Instance type {instance_type} is not supported. "
|
return {
|
||||||
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
|
"error": f"Instance type {instance_type} is not supported. "
|
||||||
|
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."
|
||||||
|
}
|
||||||
|
|
||||||
if adjust_mode == "add":
|
if adjust_mode == "add":
|
||||||
added_nodes, waiting_nodes = await proxy_state.add_instances(
|
added_nodes, waiting_nodes = await proxy_state.add_instances(instance_type, instances)
|
||||||
instance_type, instances
|
|
||||||
)
|
|
||||||
if waiting_nodes:
|
if waiting_nodes:
|
||||||
all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
|
all_msg = (
|
||||||
f"Instances {waiting_nodes} are waiting to be added."
|
f"{adjust_mode} {instance_type} instances: {added_nodes}. "
|
||||||
|
f"Instances {waiting_nodes} are waiting to be added."
|
||||||
|
)
|
||||||
elif adjust_mode == "remove":
|
elif adjust_mode == "remove":
|
||||||
if instance_type == InstanceType.PREFILL:
|
if instance_type == InstanceType.PREFILL:
|
||||||
proxy_state.remove_prefillers(instances)
|
proxy_state.remove_prefillers(instances)
|
||||||
@@ -843,14 +769,14 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
|||||||
return {
|
return {
|
||||||
"message": all_msg,
|
"message": all_msg,
|
||||||
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
|
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
|
||||||
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
|
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders],
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to {adjust_mode} instances: {e}")
|
logger.error(f"Failed to {adjust_mode} instances: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def trans_instances(instances: List[str]) -> List[ServerState]:
|
def trans_instances(instances: list[str]) -> list[ServerState]:
|
||||||
server_list = []
|
server_list = []
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
h, p = instance.split(":")
|
h, p = instance.split(":")
|
||||||
@@ -875,7 +801,7 @@ async def healthcheck():
|
|||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"prefill_instances": len(proxy_state.prefillers),
|
"prefill_instances": len(proxy_state.prefillers),
|
||||||
"decode_instances": len(proxy_state.decoders)
|
"decode_instances": len(proxy_state.decoders),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -889,7 +815,7 @@ async def handle_remove_instances(request: Request):
|
|||||||
return await _handle_adjust_instances("remove", request)
|
return await _handle_adjust_instances("remove", request)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
global global_args
|
global global_args
|
||||||
global_args = parse_args()
|
global_args = parse_args()
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|||||||
@@ -2,17 +2,17 @@
|
|||||||
|
|
||||||
## Environmental Dependencies
|
## Environmental Dependencies
|
||||||
|
|
||||||
* Software:
|
* Software:
|
||||||
* Python >= 3.10, < 3.12
|
* Python >= 3.10, < 3.12
|
||||||
* CANN == 8.3.rc2
|
* CANN == 8.3.rc2
|
||||||
* PyTorch == 2.8.0, torch-npu == 2.8.0
|
* PyTorch == 2.8.0, torch-npu == 2.8.0
|
||||||
* vLLM (same version as vllm-ascend)
|
* vLLM (same version as vllm-ascend)
|
||||||
* mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
|
* mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
|
||||||
|
|
||||||
The vllm version must be the same as the main branch of vllm-ascend, for example, 2025/07/30. The version is
|
The vllm version must be the same as the main branch of vllm-ascend, for example, 2025/07/30. The version is
|
||||||
|
|
||||||
* vllm: v0.10.1
|
* vllm: v0.10.1
|
||||||
* vllm-ascend: v0.10.1rc1
|
* vllm-ascend: v0.10.1rc1
|
||||||
|
|
||||||
## run
|
## run
|
||||||
|
|
||||||
@@ -84,7 +84,6 @@ Set `GLOO_SOCKET_IFNAME`, `TP_SOCKET_IFNAME`, and `HCCL_SOCKET_IFNAME` to the co
|
|||||||
`--gpu-memory-utilization`: Percentage of video memory occupied by the card<br>
|
`--gpu-memory-utilization`: Percentage of video memory occupied by the card<br>
|
||||||
`--kv-transfer-config`: follow kv_connector, kv_connector_module_path: mooncakeconnect, kv_buffer_device, and run on the NPU card. For kv_role, set kv_producer to the p node, kv_consumer to the d node, kv_parallel_size to 1, and kv_port to the port used by the node. For the p node, set engine_id and kv_rank to 0 and for the d node to 1. Configure the distributed parallel policy for the p and d nodes in the kv_connector_extra_config file based on --tensor-parallel-size and --data-parallel-size.<br>
|
`--kv-transfer-config`: follow kv_connector, kv_connector_module_path: mooncakeconnect, kv_buffer_device, and run on the NPU card. For kv_role, set kv_producer to the p node, kv_consumer to the d node, kv_parallel_size to 1, and kv_port to the port used by the node. For the p node, set engine_id and kv_rank to 0 and for the d node to 1. Configure the distributed parallel policy for the p and d nodes in the kv_connector_extra_config file based on --tensor-parallel-size and --data-parallel-size.<br>
|
||||||
|
|
||||||
|
|
||||||
### 2. Run `decode` Node
|
### 2. Run `decode` Node
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -151,7 +150,6 @@ python load_balance_proxy_server_example.py --host localhost --prefiller-hosts h
|
|||||||
`--decoder-hosts`: Set this parameter to the IP addresses of all d nodes. In the xpyd scenario, add the IP addresses to the end of this configuration item and leave a blank space between the IP addresses.<br>
|
`--decoder-hosts`: Set this parameter to the IP addresses of all d nodes. In the xpyd scenario, add the IP addresses to the end of this configuration item and leave a blank space between the IP addresses.<br>
|
||||||
`--decoder-ports`: Set this parameter to the port number of all d nodes, which is the configuration of the port number for the vllm to start the service in step 4. Set port to the end of the configuration, and leave a blank space between port and port. The sequence must be one-to-one mapping to the IP address of --decoder-hosts.<br>
|
`--decoder-ports`: Set this parameter to the port number of all d nodes, which is the configuration of the port number for the vllm to start the service in step 4. Set port to the end of the configuration, and leave a blank space between port and port. The sequence must be one-to-one mapping to the IP address of --decoder-hosts.<br>
|
||||||
|
|
||||||
|
|
||||||
### 4. Run Inference
|
### 4. Run Inference
|
||||||
|
|
||||||
Set the IP address in the inference file to the actual IP address. Set the model variable to the path of the model. Ensure that the path is the same as that in the shell script.
|
Set the IP address in the inference file to the actual IP address. Set the model variable to the path of the model. Ensure that the path is the same as that in the shell script.
|
||||||
|
|||||||
@@ -4,13 +4,11 @@ Expert parallelism load balancer (EPLB) for vLLM.
|
|||||||
The rearrangement algorithm is adapted from
|
The rearrangement algorithm is adapted from
|
||||||
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
|
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
|
||||||
"""
|
"""
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def balanced_packing(weight: torch.Tensor,
|
def balanced_packing(weight: torch.Tensor, num_packs: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
"""
|
||||||
Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
|
Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
|
||||||
are as balanced as possible.
|
are as balanced as possible.
|
||||||
@@ -28,26 +26,18 @@ def balanced_packing(weight: torch.Tensor,
|
|||||||
groups_per_pack = num_groups // num_packs
|
groups_per_pack = num_groups // num_packs
|
||||||
|
|
||||||
if groups_per_pack == 1:
|
if groups_per_pack == 1:
|
||||||
pack_index = torch.arange(weight.size(-1),
|
pack_index = torch.arange(weight.size(-1), dtype=torch.int64, device=weight.device).expand(weight.shape)
|
||||||
dtype=torch.int64,
|
|
||||||
device=weight.device).expand(weight.shape)
|
|
||||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
|
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
|
||||||
return pack_index, rank_in_pack
|
return pack_index, rank_in_pack
|
||||||
|
|
||||||
indices = weight.float().sort(-1, descending=True).indices.cpu()
|
indices = weight.float().sort(-1, descending=True).indices.cpu()
|
||||||
pack_index = torch.full_like(weight,
|
pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
|
||||||
fill_value=-1,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device='cpu')
|
|
||||||
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
|
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
|
||||||
for i in range(num_layers):
|
for i in range(num_layers):
|
||||||
pack_weights = [0] * num_packs
|
pack_weights = [0] * num_packs
|
||||||
pack_items = [0] * num_packs
|
pack_items = [0] * num_packs
|
||||||
for group in indices[i]:
|
for group in indices[i]:
|
||||||
pack = min(
|
pack = min((i for i in range(num_packs) if pack_items[i] < groups_per_pack), key=pack_weights.__getitem__)
|
||||||
(i
|
|
||||||
for i in range(num_packs) if pack_items[i] < groups_per_pack),
|
|
||||||
key=pack_weights.__getitem__)
|
|
||||||
assert pack_items[pack] < groups_per_pack
|
assert pack_items[pack] < groups_per_pack
|
||||||
pack_index[i, group] = pack
|
pack_index[i, group] = pack
|
||||||
rank_in_pack[i, group] = pack_items[pack]
|
rank_in_pack[i, group] = pack_items[pack]
|
||||||
@@ -56,9 +46,7 @@ def balanced_packing(weight: torch.Tensor,
|
|||||||
return pack_index, rank_in_pack
|
return pack_index, rank_in_pack
|
||||||
|
|
||||||
|
|
||||||
def replicate_experts(
|
def replicate_experts(weight: torch.Tensor, num_phy: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
weight: torch.Tensor,
|
|
||||||
num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
"""
|
||||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
|
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
|
||||||
|
|
||||||
@@ -75,8 +63,7 @@ def replicate_experts(
|
|||||||
num_redundant = num_phy - num_log
|
num_redundant = num_phy - num_log
|
||||||
assert num_redundant >= 0
|
assert num_redundant >= 0
|
||||||
device = weight.device
|
device = weight.device
|
||||||
phy2log = torch.arange(num_phy, dtype=torch.int64,
|
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||||
device=device).repeat(n, 1)
|
|
||||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||||
@@ -88,9 +75,9 @@ def replicate_experts(
|
|||||||
return phy2log, rank, logcnt
|
return phy2log, rank, logcnt
|
||||||
|
|
||||||
|
|
||||||
def rebalance_experts_hierarchical(weight: torch.Tensor,
|
def rebalance_experts_hierarchical(
|
||||||
num_physical_experts: int, num_groups: int,
|
weight: torch.Tensor, num_physical_experts: int, num_groups: int, num_nodes: int, num_gpus: int
|
||||||
num_nodes: int, num_gpus: int):
|
):
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
weight: [num_moe_layers, num_logical_experts]
|
weight: [num_moe_layers, num_logical_experts]
|
||||||
@@ -115,45 +102,37 @@ def rebalance_experts_hierarchical(weight: torch.Tensor,
|
|||||||
|
|
||||||
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
||||||
inv = torch.empty_like(perm)
|
inv = torch.empty_like(perm)
|
||||||
inv.scatter_(
|
inv.scatter_(1, perm, torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(perm.shape))
|
||||||
1, perm,
|
|
||||||
torch.arange(perm.size(1), dtype=torch.int64,
|
|
||||||
device=perm.device).expand(perm.shape))
|
|
||||||
return inv
|
return inv
|
||||||
|
|
||||||
# Step 1: pack groups to nodes
|
# Step 1: pack groups to nodes
|
||||||
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
||||||
group_pack_index, group_rank_in_pack = balanced_packing(
|
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
|
||||||
tokens_per_group, num_nodes)
|
log2mlog = (
|
||||||
log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
|
((group_pack_index * groups_per_node + group_rank_in_pack) * group_size).unsqueeze(-1)
|
||||||
group_size).unsqueeze(-1) +
|
+ torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
|
||||||
torch.arange(group_size,
|
).flatten(-2)
|
||||||
dtype=torch.int64,
|
|
||||||
device=group_pack_index.device)).flatten(-2)
|
|
||||||
mlog2log = inverse(log2mlog)
|
mlog2log = inverse(log2mlog)
|
||||||
|
|
||||||
# Step 2: construct redundant experts within nodes
|
# Step 2: construct redundant experts within nodes
|
||||||
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
||||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(
|
tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_logical_experts // num_nodes)
|
||||||
-1, num_logical_experts // num_nodes)
|
phy2mlog, phyrank, mlogcnt = replicate_experts(tokens_per_mlog, num_physical_experts // num_nodes)
|
||||||
phy2mlog, phyrank, mlogcnt = replicate_experts(
|
|
||||||
tokens_per_mlog, num_physical_experts // num_nodes)
|
|
||||||
|
|
||||||
# Step 3: pack physical_experts to GPUs
|
# Step 3: pack physical_experts to GPUs
|
||||||
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
||||||
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
||||||
pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
|
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
|
||||||
num_gpus // num_nodes)
|
|
||||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||||
pphy2phy = inverse(phy2pphy)
|
pphy2phy = inverse(phy2pphy)
|
||||||
|
|
||||||
pphy2mlog = phy2mlog.gather(
|
pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
|
||||||
-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
|
pphy2mlog = (
|
||||||
pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
|
pphy2mlog.view(num_layers, num_nodes, -1)
|
||||||
0,
|
+ torch.arange(0, num_logical_experts, num_logical_experts // num_nodes, device=group_pack_index.device).view(
|
||||||
num_logical_experts,
|
1, -1, 1
|
||||||
num_logical_experts // num_nodes,
|
)
|
||||||
device=group_pack_index.device).view(1, -1, 1)).flatten(-2)
|
).flatten(-2)
|
||||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
||||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
||||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
||||||
@@ -161,9 +140,8 @@ def rebalance_experts_hierarchical(weight: torch.Tensor,
|
|||||||
|
|
||||||
|
|
||||||
def rebalance_experts(
|
def rebalance_experts(
|
||||||
weight: torch.Tensor, num_replicas: int, num_groups: int,
|
weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, num_gpus: int
|
||||||
num_nodes: int,
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
"""
|
||||||
Entry point for expert-parallelism load balancer.
|
Entry point for expert-parallelism load balancer.
|
||||||
|
|
||||||
@@ -183,23 +161,20 @@ def rebalance_experts(
|
|||||||
weight = weight.float().cpu()
|
weight = weight.float().cpu()
|
||||||
if num_groups % num_nodes == 0:
|
if num_groups % num_nodes == 0:
|
||||||
# use hierarchical load-balance policy
|
# use hierarchical load-balance policy
|
||||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, num_groups, num_nodes, num_gpus)
|
||||||
weight, num_replicas, num_groups, num_nodes, num_gpus)
|
|
||||||
else:
|
else:
|
||||||
# use global load-balance policy
|
# use global load-balance policy
|
||||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(weight, num_replicas, 1, 1, num_gpus)
|
||||||
weight, num_replicas, 1, 1, num_gpus)
|
|
||||||
maxlogcnt = logcnt.max().item()
|
maxlogcnt = logcnt.max().item()
|
||||||
log2phy: torch.Tensor = torch.full(
|
log2phy: torch.Tensor = torch.full(
|
||||||
(num_layers, num_logical_experts, maxlogcnt),
|
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=torch.int64, device=logcnt.device
|
||||||
-1,
|
)
|
||||||
dtype=torch.int64,
|
|
||||||
device=logcnt.device)
|
|
||||||
log2phy.view(num_layers, -1).scatter_(
|
log2phy.view(num_layers, -1).scatter_(
|
||||||
-1, phy2log * maxlogcnt + phyrank,
|
-1,
|
||||||
torch.arange(num_replicas, dtype=torch.int64,
|
phy2log * maxlogcnt + phyrank,
|
||||||
device=log2phy.device).expand(num_layers, -1))
|
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(num_layers, -1),
|
||||||
|
)
|
||||||
return phy2log, log2phy, logcnt
|
return phy2log, log2phy, logcnt
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['rebalance_experts']
|
__all__ = ["rebalance_experts"]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
|
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -21,10 +20,7 @@ def save_matrix_to_json(output_path, file_name, deployment):
|
|||||||
layer = {"layer_id": i, "device_count": num_cards}
|
layer = {"layer_id": i, "device_count": num_cards}
|
||||||
device_list = []
|
device_list = []
|
||||||
for j in range(num_cards):
|
for j in range(num_cards):
|
||||||
device = {
|
device = {"device_id": j, "device_expert": deployment[i, j].tolist()}
|
||||||
"device_id": j,
|
|
||||||
"device_expert": deployment[i, j].tolist()
|
|
||||||
}
|
|
||||||
device_list.append(device)
|
device_list.append(device)
|
||||||
layer["device_list"] = device_list
|
layer["device_list"] = device_list
|
||||||
layer_list.append(layer)
|
layer_list.append(layer)
|
||||||
@@ -34,7 +30,7 @@ def save_matrix_to_json(output_path, file_name, deployment):
|
|||||||
|
|
||||||
# Save as JSON file
|
# Save as JSON file
|
||||||
try:
|
try:
|
||||||
with open(file_name, 'w') as f:
|
with open(file_name, "w") as f:
|
||||||
json.dump(data, f, indent=4)
|
json.dump(data, f, indent=4)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"write {file_name} failed: {e}")
|
print(f"write {file_name} failed: {e}")
|
||||||
@@ -63,21 +59,17 @@ def calculate_average(lst):
|
|||||||
return total / count
|
return total / count
|
||||||
|
|
||||||
|
|
||||||
def layer_imblance_polt(y_list, label_names, device_num, output_path,
|
def layer_imblance_polt(y_list, label_names, device_num, output_path, file_name):
|
||||||
file_name):
|
plt.rcParams["font.sans-serif"] = ["Arial"]
|
||||||
|
plt.rcParams["axes.unicode_minus"] = False
|
||||||
plt.rcParams['font.sans-serif'] = ['Arial']
|
|
||||||
plt.rcParams['axes.unicode_minus'] = False
|
|
||||||
x = [i for i in range(58)]
|
x = [i for i in range(58)]
|
||||||
for index, y in enumerate(y_list):
|
for index, y in enumerate(y_list):
|
||||||
plt.plot(x,
|
plt.plot(x, y, label=rf"{label_names[index]},avg={calculate_average(y)}")
|
||||||
y,
|
|
||||||
label=rf'{label_names[index]},avg={calculate_average(y)}')
|
|
||||||
|
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.title(rf'Load Distribution (num_gpus={device_num})')
|
plt.title(rf"Load Distribution (num_gpus={device_num})")
|
||||||
plt.xlabel('layer')
|
plt.xlabel("layer")
|
||||||
plt.ylabel('Device Load')
|
plt.ylabel("Device Load")
|
||||||
|
|
||||||
# Show grid lines
|
# Show grid lines
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
@@ -88,27 +80,23 @@ def layer_imblance_polt(y_list, label_names, device_num, output_path,
|
|||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes,
|
def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes, num_gpus, num_original_expert):
|
||||||
num_gpus, num_original_expert):
|
|
||||||
from eplb_deepseek import rebalance_experts
|
from eplb_deepseek import rebalance_experts
|
||||||
|
|
||||||
num_replicas = num_original_expert + num_redundancy_expert
|
num_replicas = num_original_expert + num_redundancy_expert
|
||||||
hy2log, log2phy, logcnt = rebalance_experts(workload, num_replicas,
|
hy2log, log2phy, logcnt = rebalance_experts(workload, num_replicas, num_groups, num_nodes, num_gpus)
|
||||||
num_groups, num_nodes,
|
|
||||||
num_gpus)
|
|
||||||
|
|
||||||
# Convert to global_deployment
|
# Convert to global_deployment
|
||||||
workload = workload.cpu().numpy()
|
workload = workload.cpu().numpy()
|
||||||
global_deployment = []
|
global_deployment = []
|
||||||
layer_num = log2phy.shape[0]
|
layer_num = log2phy.shape[0]
|
||||||
num_physical_experts_local = (num_original_expert +
|
num_physical_experts_local = (num_original_expert + num_redundancy_expert) // num_gpus
|
||||||
num_redundancy_expert) // num_gpus
|
|
||||||
for layer_idx in range(layer_num):
|
for layer_idx in range(layer_num):
|
||||||
layer_deployment = []
|
layer_deployment = []
|
||||||
for gpu_idx in range(num_gpus):
|
for gpu_idx in range(num_gpus):
|
||||||
local_deployment = hy2log[layer_idx][gpu_idx *
|
local_deployment = hy2log[layer_idx][
|
||||||
num_physical_experts_local:
|
gpu_idx * num_physical_experts_local : (gpu_idx + 1) * num_physical_experts_local
|
||||||
(gpu_idx + 1) *
|
]
|
||||||
num_physical_experts_local]
|
|
||||||
local_deployment = local_deployment.flatten()
|
local_deployment = local_deployment.flatten()
|
||||||
layer_deployment.append(local_deployment.tolist())
|
layer_deployment.append(local_deployment.tolist())
|
||||||
global_deployment.append(layer_deployment)
|
global_deployment.append(layer_deployment)
|
||||||
@@ -122,18 +110,15 @@ def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes,
|
|||||||
new_value = workload[layer_idx].reshape(num_gpus, -1)
|
new_value = workload[layer_idx].reshape(num_gpus, -1)
|
||||||
row_sum = np.sum(new_value, axis=1)
|
row_sum = np.sum(new_value, axis=1)
|
||||||
original_weights.append(row_sum.max())
|
original_weights.append(row_sum.max())
|
||||||
average_weights.append((np.sum(workload[layer_idx]) / num_gpus))
|
average_weights.append(np.sum(workload[layer_idx]) / num_gpus)
|
||||||
|
|
||||||
opt_workload = np.zeros((num_original_expert + num_redundancy_expert),
|
opt_workload = np.zeros((num_original_expert + num_redundancy_expert), dtype=np.float64)
|
||||||
dtype=np.float64)
|
|
||||||
for expert_idx in range(num_original_expert):
|
for expert_idx in range(num_original_expert):
|
||||||
physical_expert_idxs = log2phy[layer_idx][expert_idx]
|
physical_expert_idxs = log2phy[layer_idx][expert_idx]
|
||||||
physical_expert_idxs = physical_expert_idxs.flatten()
|
physical_expert_idxs = physical_expert_idxs.flatten()
|
||||||
physical_expert_idxs = physical_expert_idxs[
|
physical_expert_idxs = physical_expert_idxs[physical_expert_idxs != -1]
|
||||||
physical_expert_idxs != -1]
|
|
||||||
for physical_expert_idx in physical_expert_idxs:
|
for physical_expert_idx in physical_expert_idxs:
|
||||||
opt_workload[physical_expert_idx] += workload[layer_idx][
|
opt_workload[physical_expert_idx] += workload[layer_idx][expert_idx] / len(physical_expert_idxs)
|
||||||
expert_idx] / len(physical_expert_idxs)
|
|
||||||
opt_workload = opt_workload.reshape(num_gpus, -1)
|
opt_workload = opt_workload.reshape(num_gpus, -1)
|
||||||
row_sum = np.sum(opt_workload, axis=1)
|
row_sum = np.sum(opt_workload, axis=1)
|
||||||
max_weights.append(row_sum.max())
|
max_weights.append(row_sum.max())
|
||||||
@@ -142,8 +127,9 @@ def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes,
|
|||||||
return global_deployment, y_list
|
return global_deployment, y_list
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--exp_name", type=str, default="gsm8k_temp0.0")
|
parser.add_argument("--exp_name", type=str, default="gsm8k_temp0.0")
|
||||||
parser.add_argument("--num_original_expert", type=int, default=256)
|
parser.add_argument("--num_original_expert", type=int, default=256)
|
||||||
@@ -165,19 +151,13 @@ if __name__ == '__main__':
|
|||||||
num_nodes = args.num_nodes
|
num_nodes = args.num_nodes
|
||||||
|
|
||||||
# NOTE: assume input workload format: [layer_num, num_experts]
|
# NOTE: assume input workload format: [layer_num, num_experts]
|
||||||
workload = torch.load(input_path, map_location=torch.device('cpu'))
|
workload = torch.load(input_path, map_location=torch.device("cpu"))
|
||||||
global_deployment, y_list = deepseek_deploy(workload,
|
global_deployment, y_list = deepseek_deploy(
|
||||||
num_redundancy_expert,
|
workload, num_redundancy_expert, num_groups, num_nodes, num_devices, num_original_expert
|
||||||
num_groups, num_nodes,
|
)
|
||||||
num_devices,
|
|
||||||
num_original_expert)
|
|
||||||
|
|
||||||
file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}"
|
file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}"
|
||||||
save_matrix_to_json(output_path, file_name, np.array(global_deployment))
|
save_matrix_to_json(output_path, file_name, np.array(global_deployment))
|
||||||
label_names = [
|
label_names = ["default deployment max load", "balanced load max load", "balanced load avg load"]
|
||||||
'default deployment max load', 'balanced load max load',
|
|
||||||
'balanced load avg load'
|
|
||||||
]
|
|
||||||
new_file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}.png"
|
new_file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}.png"
|
||||||
layer_imblance_polt(y_list, label_names, num_devices, output_path,
|
layer_imblance_polt(y_list, label_names, num_devices, output_path, new_file_name)
|
||||||
new_file_name)
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ All the arguments that can be set by users are:
|
|||||||
7. `--vllm-start-port`: Starting port of vLLM serving instances, default 9000
|
7. `--vllm-start-port`: Starting port of vLLM serving instances, default 9000
|
||||||
|
|
||||||
An example of running external DP in one single node:
|
An example of running external DP in one single node:
|
||||||
|
|
||||||
```(python)
|
```(python)
|
||||||
cd examples/external_online_dp
|
cd examples/external_online_dp
|
||||||
# running DP4 TP4 in a node with 16 NPUs
|
# running DP4 TP4 in a node with 16 NPUs
|
||||||
@@ -25,6 +26,7 @@ python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 4 --dp-rank-s
|
|||||||
```
|
```
|
||||||
|
|
||||||
An example of running external DP in two nodes:
|
An example of running external DP in two nodes:
|
||||||
|
|
||||||
```(python)
|
```(python)
|
||||||
cd examples/external_online_dp
|
cd examples/external_online_dp
|
||||||
# running DP4 TP4 in two nodes with 8 NPUs each
|
# running DP4 TP4 in two nodes with 8 NPUs each
|
||||||
|
|||||||
@@ -84,13 +84,12 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import heapq
|
import heapq
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
@@ -109,34 +108,29 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
class ServerState:
|
class ServerState:
|
||||||
|
|
||||||
def __init__(self, host, port):
|
def __init__(self, host, port):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.url = f'http://{host}:{port}/v1'
|
self.url = f"http://{host}:{port}/v1"
|
||||||
self.client = httpx.AsyncClient(timeout=None,
|
self.client = httpx.AsyncClient(
|
||||||
base_url=self.url,
|
timeout=None,
|
||||||
limits=httpx.Limits(
|
base_url=self.url,
|
||||||
max_connections=100000,
|
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
|
||||||
max_keepalive_connections=100000))
|
)
|
||||||
self.active_tokens = 0
|
self.active_tokens = 0
|
||||||
self.aborted_requests = set() # Track aborted requests
|
self.aborted_requests = set() # Track aborted requests
|
||||||
|
|
||||||
|
|
||||||
class ProxyState:
|
class ProxyState:
|
||||||
|
|
||||||
def __init__(self, server_instances):
|
def __init__(self, server_instances):
|
||||||
self.dp_servers: List[ServerState] = [
|
self.dp_servers: list[ServerState] = [ServerState(h, p) for h, p in server_instances]
|
||||||
ServerState(h, p) for h, p in server_instances
|
|
||||||
]
|
|
||||||
self.req_id_lock = asyncio.Lock()
|
self.req_id_lock = asyncio.Lock()
|
||||||
# Removed selection locks - no longer needed for synchronous methods
|
# Removed selection locks - no longer needed for synchronous methods
|
||||||
|
|
||||||
# Initialize priority queues for efficient server selection
|
# Initialize priority queues for efficient server selection
|
||||||
# Each entry is (priority_score, server_index, server_reference)
|
# Each entry is (priority_score, server_index, server_reference)
|
||||||
# Lower priority score = higher priority (less loaded)
|
# Lower priority score = higher priority (less loaded)
|
||||||
self.lb_heap = [(0, i, server)
|
self.lb_heap = [(0, i, server) for i, server in enumerate(self.dp_servers)]
|
||||||
for i, server in enumerate(self.dp_servers)]
|
|
||||||
heapq.heapify(self.lb_heap)
|
heapq.heapify(self.lb_heap)
|
||||||
|
|
||||||
def _update_server_priority(self, server_idx: int):
|
def _update_server_priority(self, server_idx: int):
|
||||||
@@ -144,10 +138,8 @@ class ProxyState:
|
|||||||
server = self.dp_servers[server_idx]
|
server = self.dp_servers[server_idx]
|
||||||
priority = server.active_tokens
|
priority = server.active_tokens
|
||||||
# Remove old entry and add new one
|
# Remove old entry and add new one
|
||||||
self.lb_heap = [(p, i, s) for p, i, s in self.lb_heap
|
self.lb_heap = [(p, i, s) for p, i, s in self.lb_heap if i != server_idx]
|
||||||
if i != server_idx]
|
heapq.heappush(self.lb_heap, (priority, server_idx, server)) # type: ignore
|
||||||
heapq.heappush(self.lb_heap,
|
|
||||||
(priority, server_idx, server)) # type: ignore
|
|
||||||
|
|
||||||
async def next_req_id(self):
|
async def next_req_id(self):
|
||||||
async with self.req_id_lock:
|
async with self.req_id_lock:
|
||||||
@@ -190,27 +182,15 @@ def parse_args():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
parser.add_argument("--dp-hosts",
|
parser.add_argument("--dp-hosts", type=str, nargs="+", default=["localhost"])
|
||||||
type=str,
|
parser.add_argument("--dp-ports", type=int, nargs="+", default=[8001])
|
||||||
nargs="+",
|
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
|
||||||
default=["localhost"])
|
|
||||||
parser.add_argument("--dp-ports",
|
|
||||||
type=int,
|
|
||||||
nargs="+",
|
|
||||||
default=[8001])
|
|
||||||
parser.add_argument("--max-retries",
|
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="Maximum number of retries for HTTP requests")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--retry-delay",
|
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
|
||||||
type=float,
|
)
|
||||||
default=0.001,
|
|
||||||
help="Base delay (seconds) for exponential backoff retries")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if len(args.dp_hosts) != len(args.dp_ports):
|
if len(args.dp_hosts) != len(args.dp_ports):
|
||||||
raise ValueError(
|
raise ValueError("Number of dp hosts must match number of dp ports")
|
||||||
"Number of dp hosts must match number of dp ports")
|
|
||||||
args.server_instances = list(zip(args.dp_hosts, args.dp_ports))
|
args.server_instances = list(zip(args.dp_hosts, args.dp_ports))
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -219,9 +199,7 @@ def parse_args():
|
|||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
global proxy_state
|
global proxy_state
|
||||||
proxy_state = ProxyState(global_args.server_instances)
|
proxy_state = ProxyState(global_args.server_instances)
|
||||||
print(
|
print(f"Initialized {len(proxy_state.dp_servers)} dp server clients.")
|
||||||
f"Initialized {len(proxy_state.dp_servers)} dp server clients."
|
|
||||||
)
|
|
||||||
yield
|
yield
|
||||||
for p in proxy_state.dp_servers:
|
for p in proxy_state.dp_servers:
|
||||||
await p.client.aclose()
|
await p.client.aclose()
|
||||||
@@ -236,14 +214,12 @@ async def listen_for_disconnect(request: Request) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def with_cancellation(handler_func):
|
def with_cancellation(handler_func):
|
||||||
|
|
||||||
@functools.wraps(handler_func)
|
@functools.wraps(handler_func)
|
||||||
async def wrapper(*args, **kwargs):
|
async def wrapper(*args, **kwargs):
|
||||||
request = kwargs["request"]
|
request = kwargs["request"]
|
||||||
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
||||||
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
||||||
done, pending = await asyncio.wait([handler_task, cancellation_task],
|
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
|
||||||
return_when=asyncio.FIRST_COMPLETED)
|
|
||||||
for task in pending:
|
for task in pending:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
if handler_task in done:
|
if handler_task in done:
|
||||||
@@ -256,22 +232,18 @@ def with_cancellation(handler_func):
|
|||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
async def stream_service_response_with_retry(
|
||||||
endpoint: str,
|
client: httpx.AsyncClient,
|
||||||
req_data: dict,
|
endpoint: str,
|
||||||
request_id: str,
|
req_data: dict,
|
||||||
max_retries: int = 3,
|
request_id: str,
|
||||||
base_delay: float = 0.2):
|
max_retries: int = 3,
|
||||||
headers = {
|
base_delay: float = 0.2,
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
):
|
||||||
"X-Request-Id": request_id
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
|
||||||
}
|
|
||||||
for attempt in range(1, max_retries + 1):
|
for attempt in range(1, max_retries + 1):
|
||||||
try:
|
try:
|
||||||
async with client.stream("POST",
|
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
|
||||||
endpoint,
|
|
||||||
json=req_data,
|
|
||||||
headers=headers) as response:
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
first_chunk_sent = False
|
first_chunk_sent = False
|
||||||
async for chunk in response.aiter_bytes():
|
async for chunk in response.aiter_bytes():
|
||||||
@@ -280,53 +252,42 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
|||||||
return # Success, exit after streaming
|
return # Success, exit after streaming
|
||||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||||
if attempt < max_retries:
|
if attempt < max_retries:
|
||||||
logger.warning(
|
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
|
||||||
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||||
)
|
|
||||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
|
||||||
f"All {max_retries} attempts failed for streaming {endpoint}."
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If any chunk has been sent, do not retry, just log and drop
|
# If any chunk has been sent, do not retry, just log and drop
|
||||||
if 'first_chunk_sent' in locals() and first_chunk_sent:
|
if "first_chunk_sent" in locals() and first_chunk_sent:
|
||||||
logger.error(
|
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
|
||||||
f"Streaming to client interrupted after response started: {str(e)}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
if attempt < max_retries:
|
if attempt < max_retries:
|
||||||
logger.warning(
|
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
|
||||||
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||||
)
|
|
||||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
|
||||||
f"All {max_retries} attempts failed for streaming {endpoint}."
|
|
||||||
)
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
async def _select_instance(api: str, req_data: Any,
|
async def _select_instance(api: str, req_data: Any, request_length: int):
|
||||||
request_length: int):
|
|
||||||
# refer to vLLM sampling_params: max_token default value
|
# refer to vLLM sampling_params: max_token default value
|
||||||
max_tokens = req_data.get("max_tokens", 16)
|
max_tokens = req_data.get("max_tokens", 16)
|
||||||
ignore_eos = req_data.get("ignore_eos", False)
|
ignore_eos = req_data.get("ignore_eos", False)
|
||||||
priority_score = proxy_state.calculate_request_score(request_length,max_tokens=max_tokens, ignore_eos=ignore_eos)
|
priority_score = proxy_state.calculate_request_score(request_length, max_tokens=max_tokens, ignore_eos=ignore_eos)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Request length: {request_length}, max tokens: {max_tokens}, ignore_eos: {ignore_eos}, Priority score: {priority_score}"
|
f"Request length: {request_length}, max tokens: {max_tokens}, "
|
||||||
|
f"ignore_eos: {ignore_eos}, Priority score: {priority_score}"
|
||||||
)
|
)
|
||||||
request_id = await proxy_state.next_req_id()
|
request_id = await proxy_state.next_req_id()
|
||||||
# Select dp server based on priority score
|
# Select dp server based on priority score
|
||||||
server_idx = proxy_state.select_server(priority_score)
|
server_idx = proxy_state.select_server(priority_score)
|
||||||
choosen_server = proxy_state.dp_servers[server_idx]
|
choosen_server = proxy_state.dp_servers[server_idx]
|
||||||
logger.debug(f"Choose server {choosen_server.url} to process request {request_id}")
|
logger.debug(f"Choose server {choosen_server.url} to process request {request_id}")
|
||||||
return InstanceInfo(request_id=request_id,
|
return InstanceInfo(
|
||||||
server_idx=server_idx,
|
request_id=request_id, server_idx=server_idx, priority_score=priority_score, server_state=choosen_server
|
||||||
priority_score=priority_score,
|
)
|
||||||
server_state=choosen_server)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -342,36 +303,36 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
req_data = await request.json()
|
req_data = await request.json()
|
||||||
req_body = await request.body()
|
req_body = await request.body()
|
||||||
request_length = len(req_body)
|
request_length = len(req_body)
|
||||||
instance_info = await _select_instance(api, req_data,
|
instance_info = await _select_instance(api, req_data, request_length)
|
||||||
request_length)
|
|
||||||
async def generate_stream():
|
async def generate_stream():
|
||||||
nonlocal instance_info
|
nonlocal instance_info
|
||||||
# Only one await per chunk, minimal logic in loop
|
# Only one await per chunk, minimal logic in loop
|
||||||
try:
|
try:
|
||||||
async for chunk in stream_service_response_with_retry(
|
async for chunk in stream_service_response_with_retry(
|
||||||
instance_info.server_state.client,
|
instance_info.server_state.client,
|
||||||
api,
|
api,
|
||||||
req_data,
|
req_data,
|
||||||
request_id=instance_info.request_id,
|
request_id=instance_info.request_id,
|
||||||
max_retries=global_args.max_retries,
|
max_retries=global_args.max_retries,
|
||||||
base_delay=global_args.retry_delay):
|
base_delay=global_args.retry_delay,
|
||||||
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error during streaming from server {instance_info.server_state.url}: {str(e)}, the aborted request is: {instance_info.request_id}."
|
f"Error during streaming from server {instance_info.server_state.url}: {str(e)}, "
|
||||||
|
f"the aborted request is: {instance_info.request_id}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# After streaming done, release tokens
|
# After streaming done, release tokens
|
||||||
proxy_state.release_server(instance_info.server_idx,
|
proxy_state.release_server(instance_info.server_idx, instance_info.priority_score)
|
||||||
instance_info.priority_score)
|
|
||||||
|
|
||||||
return StreamingResponse(generate_stream(),
|
return StreamingResponse(generate_stream(), media_type="application/json")
|
||||||
media_type="application/json")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
print("Error occurred in external dp proxy server"
|
print(f"Error occurred in external dp proxy server - {api} endpoint")
|
||||||
f" - {api} endpoint")
|
|
||||||
print(e)
|
print(e)
|
||||||
print("".join(traceback.format_exception(*exc_info)))
|
print("".join(traceback.format_exception(*exc_info)))
|
||||||
raise
|
raise
|
||||||
@@ -397,7 +358,7 @@ async def healthcheck():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
global global_args
|
global global_args
|
||||||
global_args = parse_args()
|
global_args = parse_args()
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|||||||
@@ -4,52 +4,19 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument("--dp-size", type=int, required=True, help="Data parallel size.")
|
||||||
"--dp-size",
|
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size.")
|
||||||
type=int,
|
parser.add_argument("--dp-size-local", type=int, default=-1, help="Local data parallel size.")
|
||||||
required=True,
|
parser.add_argument("--dp-rank-start", type=int, default=0, help="Starting rank for data parallel.")
|
||||||
help="Data parallel size."
|
parser.add_argument("--dp-address", type=str, required=True, help="IP address for data parallel master node.")
|
||||||
)
|
parser.add_argument("--dp-rpc-port", type=str, default=12345, help="Port for data parallel master node.")
|
||||||
parser.add_argument(
|
parser.add_argument("--vllm-start-port", type=int, default=9000, help="Starting port for the engine.")
|
||||||
"--tp-size",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Tensor parallel size."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dp-size-local",
|
|
||||||
type=int,
|
|
||||||
default=-1,
|
|
||||||
help="Local data parallel size."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dp-rank-start",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Starting rank for data parallel."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dp-address",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="IP address for data parallel master node."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dp-rpc-port",
|
|
||||||
type=str,
|
|
||||||
default=12345,
|
|
||||||
help="Port for data parallel master node."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--vllm-start-port",
|
|
||||||
type=int,
|
|
||||||
default=9000,
|
|
||||||
help="Starting port for the engine."
|
|
||||||
)
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
dp_size = args.dp_size
|
dp_size = args.dp_size
|
||||||
tp_size = args.tp_size
|
tp_size = args.tp_size
|
||||||
@@ -61,6 +28,7 @@ dp_address = args.dp_address
|
|||||||
dp_rpc_port = args.dp_rpc_port
|
dp_rpc_port = args.dp_rpc_port
|
||||||
vllm_start_port = args.vllm_start_port
|
vllm_start_port = args.vllm_start_port
|
||||||
|
|
||||||
|
|
||||||
def run_command(visiable_devices, dp_rank, vllm_engine_port):
|
def run_command(visiable_devices, dp_rank, vllm_engine_port):
|
||||||
command = [
|
command = [
|
||||||
"bash",
|
"bash",
|
||||||
@@ -75,6 +43,7 @@ def run_command(visiable_devices, dp_rank, vllm_engine_port):
|
|||||||
]
|
]
|
||||||
subprocess.run(command, check=True)
|
subprocess.run(command, check=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
template_path = "./run_dp_template.sh"
|
template_path = "./run_dp_template.sh"
|
||||||
if not os.path.exists(template_path):
|
if not os.path.exists(template_path):
|
||||||
@@ -87,9 +56,7 @@ if __name__ == "__main__":
|
|||||||
dp_rank = dp_rank_start + i
|
dp_rank = dp_rank_start + i
|
||||||
vllm_engine_port = vllm_start_port + i
|
vllm_engine_port = vllm_start_port + i
|
||||||
visiable_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
|
visiable_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
|
||||||
process = multiprocessing.Process(target=run_command,
|
process = multiprocessing.Process(target=run_command, args=(visiable_devices, dp_rank, vllm_engine_port))
|
||||||
args=(visiable_devices, dp_rank,
|
|
||||||
vllm_engine_port))
|
|
||||||
processes.append(process)
|
processes.append(process)
|
||||||
process.start()
|
process.start()
|
||||||
|
|
||||||
|
|||||||
@@ -61,13 +61,13 @@ from time import sleep
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.distributed.parallel_state import ( # noqa E402
|
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel # noqa E402
|
||||||
destroy_distributed_environment, destroy_model_parallel)
|
|
||||||
from vllm.utils.network_utils import get_open_port
|
from vllm.utils.network_utils import get_open_port
|
||||||
|
|
||||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
@@ -78,43 +78,18 @@ def parse_args():
|
|||||||
default="ibm-research/PowerMoE-3b",
|
default="ibm-research/PowerMoE-3b",
|
||||||
help="Model name or path",
|
help="Model name or path",
|
||||||
)
|
)
|
||||||
parser.add_argument("--dp-size",
|
parser.add_argument("--dp-size", type=int, default=2, help="Data parallel size")
|
||||||
type=int,
|
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
|
||||||
default=2,
|
parser.add_argument("--node-size", type=int, default=1, help="Total number of nodes")
|
||||||
help="Data parallel size")
|
parser.add_argument("--node-rank", type=int, default=0, help="Rank of the current node")
|
||||||
parser.add_argument("--tp-size",
|
parser.add_argument("--master-addr", type=str, default="", help="Master node IP address")
|
||||||
type=int,
|
parser.add_argument("--master-port", type=int, default=0, help="Master node port")
|
||||||
default=1,
|
parser.add_argument("--enforce-eager", action="store_true", help="Enforce eager mode execution.")
|
||||||
help="Tensor parallel size")
|
parser.add_argument("--trust-remote-code", action="store_true", help="Trust remote code.")
|
||||||
parser.add_argument("--node-size",
|
parser.add_argument(
|
||||||
type=int,
|
"--enable-expert-parallel", action="store_true", help="Enable expert parallel, used in MOE models."
|
||||||
default=1,
|
)
|
||||||
help="Total number of nodes")
|
parser.add_argument("--quantization", type=str, default="", help="Use quantization models")
|
||||||
parser.add_argument("--node-rank",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Rank of the current node")
|
|
||||||
parser.add_argument("--master-addr",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="Master node IP address")
|
|
||||||
parser.add_argument("--master-port",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Master node port")
|
|
||||||
parser.add_argument("--enforce-eager",
|
|
||||||
action="store_true",
|
|
||||||
help="Enforce eager mode execution.")
|
|
||||||
parser.add_argument("--trust-remote-code",
|
|
||||||
action="store_true",
|
|
||||||
help="Trust remote code.")
|
|
||||||
parser.add_argument("--enable-expert-parallel",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable expert parallel, used in MOE models.")
|
|
||||||
parser.add_argument("--quantization",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="Use quantization models")
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -127,6 +102,7 @@ def cleanup_env_and_memory():
|
|||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
torch.npu.reset_peak_memory_stats()
|
torch.npu.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
model,
|
model,
|
||||||
dp_size,
|
dp_size,
|
||||||
@@ -168,7 +144,7 @@ def main(
|
|||||||
def start(rank):
|
def start(rank):
|
||||||
return rank * floor + min(rank, remainder)
|
return rank * floor + min(rank, remainder)
|
||||||
|
|
||||||
prompts = prompts[start(global_dp_rank):start(global_dp_rank + 1)]
|
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
|
||||||
if len(prompts) == 0:
|
if len(prompts) == 0:
|
||||||
# if any rank has no prompts to process,
|
# if any rank has no prompts to process,
|
||||||
# we need to set a placeholder prompt
|
# we need to set a placeholder prompt
|
||||||
@@ -179,9 +155,7 @@ def main(
|
|||||||
# since we are doing data parallel, every rank can have different
|
# since we are doing data parallel, every rank can have different
|
||||||
# sampling params. here we set different max_tokens for different
|
# sampling params. here we set different max_tokens for different
|
||||||
# ranks for demonstration.
|
# ranks for demonstration.
|
||||||
sampling_params = SamplingParams(temperature=0.8,
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=[16, 20][global_dp_rank % 2])
|
||||||
top_p=0.95,
|
|
||||||
max_tokens=[16, 20][global_dp_rank % 2])
|
|
||||||
|
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
@@ -200,14 +174,14 @@ def main(
|
|||||||
break
|
break
|
||||||
prompt = output.prompt
|
prompt = output.prompt
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, "
|
print(f"DP rank {global_dp_rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
f"Generated text: {generated_text!r}")
|
|
||||||
|
|
||||||
# Give engines time to pause their processing loops before exiting.
|
# Give engines time to pause their processing loops before exiting.
|
||||||
sleep(5)
|
sleep(5)
|
||||||
del llm
|
del llm
|
||||||
cleanup_env_and_memory()
|
cleanup_env_and_memory()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
@@ -231,8 +205,7 @@ if __name__ == "__main__":
|
|||||||
from multiprocessing import Process
|
from multiprocessing import Process
|
||||||
|
|
||||||
procs = []
|
procs = []
|
||||||
for local_dp_rank, global_dp_rank in enumerate(
|
for local_dp_rank, global_dp_rank in enumerate(range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
|
||||||
range(node_rank * dp_per_node, (node_rank + 1) * dp_per_node)):
|
|
||||||
proc = Process(
|
proc = Process(
|
||||||
target=main,
|
target=main,
|
||||||
args=(
|
args=(
|
||||||
@@ -255,9 +228,7 @@ if __name__ == "__main__":
|
|||||||
for proc in procs:
|
for proc in procs:
|
||||||
proc.join(timeout=900)
|
proc.join(timeout=900)
|
||||||
if proc.exitcode is None:
|
if proc.exitcode is None:
|
||||||
print(
|
print(f"Killing process {proc.pid} that didn't stop within 15 minutes.")
|
||||||
f"Killing process {proc.pid} that didn't stop within 15 minutes."
|
|
||||||
)
|
|
||||||
proc.kill()
|
proc.kill()
|
||||||
exit_code = 1
|
exit_code = 1
|
||||||
elif proc.exitcode:
|
elif proc.exitcode:
|
||||||
|
|||||||
@@ -29,8 +29,8 @@ def clean_up():
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
|
||||||
destroy_distributed_environment, destroy_model_parallel)
|
|
||||||
destroy_model_parallel()
|
destroy_model_parallel()
|
||||||
destroy_distributed_environment()
|
destroy_distributed_environment()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -44,8 +44,10 @@ def run_prefill(prefill_done, process_close):
|
|||||||
from vllm.config import KVTransferConfig
|
from vllm.config import KVTransferConfig
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, how are you today?", "Hi, what is your name?",
|
"Hello, how are you today?",
|
||||||
"Tell me a very long story.", "what is your favourite book?"
|
"Hi, what is your name?",
|
||||||
|
"Tell me a very long story.",
|
||||||
|
"what is your favourite book?",
|
||||||
]
|
]
|
||||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||||
|
|
||||||
@@ -55,22 +57,16 @@ def run_prefill(prefill_done, process_close):
|
|||||||
kv_port="30000",
|
kv_port="30000",
|
||||||
engine_id="0",
|
engine_id="0",
|
||||||
kv_connector_module_path="vllm_ascend.distributed.mooncake_connector",
|
kv_connector_module_path="vllm_ascend.distributed.mooncake_connector",
|
||||||
kv_connector_extra_config={
|
kv_connector_extra_config={"prefill": {"dp_size": 1, "tp_size": 1}, "decode": {"dp_size": 1, "tp_size": 1}},
|
||||||
"prefill": {
|
)
|
||||||
"dp_size": 1,
|
|
||||||
"tp_size": 1
|
|
||||||
},
|
|
||||||
"decode": {
|
|
||||||
"dp_size": 1,
|
|
||||||
"tp_size": 1
|
|
||||||
}
|
|
||||||
})
|
|
||||||
# Set NPU memory utilization to 0.8
|
# Set NPU memory utilization to 0.8
|
||||||
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
llm = LLM(
|
||||||
kv_transfer_config=ktc,
|
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||||
max_model_len=2000,
|
kv_transfer_config=ktc,
|
||||||
gpu_memory_utilization=0.8,
|
max_model_len=2000,
|
||||||
tensor_parallel_size=1)
|
gpu_memory_utilization=0.8,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
llm.generate(prompts, sampling_params)
|
llm.generate(prompts, sampling_params)
|
||||||
print("Prefill node is finished.")
|
print("Prefill node is finished.")
|
||||||
@@ -96,8 +92,10 @@ def run_decode(prefill_done):
|
|||||||
from vllm.config import KVTransferConfig
|
from vllm.config import KVTransferConfig
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
"Hello, how are you today?", "Hi, what is your name?",
|
"Hello, how are you today?",
|
||||||
"Tell me a very long story.", "what is your favourite book?"
|
"Hi, what is your name?",
|
||||||
|
"Tell me a very long story.",
|
||||||
|
"what is your favourite book?",
|
||||||
]
|
]
|
||||||
sampling_params = SamplingParams(temperature=0, top_p=0.95)
|
sampling_params = SamplingParams(temperature=0, top_p=0.95)
|
||||||
|
|
||||||
@@ -107,22 +105,16 @@ def run_decode(prefill_done):
|
|||||||
kv_port="30100",
|
kv_port="30100",
|
||||||
engine_id="1",
|
engine_id="1",
|
||||||
kv_connector_module_path="vllm_ascend.distributed.mooncake_connector",
|
kv_connector_module_path="vllm_ascend.distributed.mooncake_connector",
|
||||||
kv_connector_extra_config={
|
kv_connector_extra_config={"prefill": {"dp_size": 1, "tp_size": 1}, "decode": {"dp_size": 1, "tp_size": 1}},
|
||||||
"prefill": {
|
)
|
||||||
"dp_size": 1,
|
|
||||||
"tp_size": 1
|
|
||||||
},
|
|
||||||
"decode": {
|
|
||||||
"dp_size": 1,
|
|
||||||
"tp_size": 1
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
llm = LLM(
|
||||||
kv_transfer_config=ktc,
|
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||||
max_model_len=2000,
|
kv_transfer_config=ktc,
|
||||||
gpu_memory_utilization=0.8,
|
max_model_len=2000,
|
||||||
tensor_parallel_size=1)
|
gpu_memory_utilization=0.8,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for the producer to start the consumer
|
# Wait for the producer to start the consumer
|
||||||
print("Waiting for prefill node to finish...")
|
print("Waiting for prefill node to finish...")
|
||||||
@@ -141,16 +133,18 @@ def run_decode(prefill_done):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mp.get_context('spawn')
|
mp.get_context("spawn")
|
||||||
|
|
||||||
prefill_done = Event()
|
prefill_done = Event()
|
||||||
process_close = Event()
|
process_close = Event()
|
||||||
prefill_process = Process(target=run_prefill,
|
prefill_process = Process(
|
||||||
args=(
|
target=run_prefill,
|
||||||
prefill_done,
|
args=(
|
||||||
process_close,
|
prefill_done,
|
||||||
))
|
process_close,
|
||||||
decode_process = Process(target=run_decode, args=(prefill_done, ))
|
),
|
||||||
|
)
|
||||||
|
decode_process = Process(target=run_decode, args=(prefill_done,))
|
||||||
|
|
||||||
# Start prefill node
|
# Start prefill node
|
||||||
prefill_process.start()
|
prefill_process.start()
|
||||||
|
|||||||
@@ -25,22 +25,24 @@ from vllm import LLM
|
|||||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
|
|
||||||
|
|
||||||
def get_detailed_instruct(task_description: str, query: str) -> str:
|
def get_detailed_instruct(task_description: str, query: str) -> str:
|
||||||
return f'Instruct: {task_description}\nQuery:{query}'
|
return f"Instruct: {task_description}\nQuery:{query}"
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Each query must come with a one-sentence instruction that describes the task
|
# Each query must come with a one-sentence instruction that describes the task
|
||||||
task = 'Given a web search query, retrieve relevant passages that answer the query'
|
task = "Given a web search query, retrieve relevant passages that answer the query"
|
||||||
|
|
||||||
queries = [
|
queries = [
|
||||||
get_detailed_instruct(task, 'What is the capital of China?'),
|
get_detailed_instruct(task, "What is the capital of China?"),
|
||||||
get_detailed_instruct(task, 'Explain gravity')
|
get_detailed_instruct(task, "Explain gravity"),
|
||||||
]
|
]
|
||||||
# No need to add instruction for retrieval documents
|
# No need to add instruction for retrieval documents
|
||||||
documents = [
|
documents = [
|
||||||
"The capital of China is Beijing.",
|
"The capital of China is Beijing.",
|
||||||
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
|
"Gravity is a force that attracts two bodies towards each other. "
|
||||||
|
"It gives weight to physical objects and is responsible for the movement of planets around the sun.",
|
||||||
]
|
]
|
||||||
input_texts = queries + documents
|
input_texts = queries + documents
|
||||||
|
|
||||||
@@ -49,7 +51,7 @@ def main():
|
|||||||
outputs = model.embed(input_texts)
|
outputs = model.embed(input_texts)
|
||||||
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
|
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
|
||||||
# Calculate the similarity scores between the first two queries and the last two documents
|
# Calculate the similarity scores between the first two queries and the last two documents
|
||||||
scores = (embeddings[:2] @ embeddings[2:].T)
|
scores = embeddings[:2] @ embeddings[2:].T
|
||||||
print(scores.tolist())
|
print(scores.tolist())
|
||||||
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
|
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
|
||||||
|
|
||||||
|
|||||||
@@ -63,10 +63,13 @@ from multiprocessing import Process
|
|||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.distributed.parallel_state import ( # noqa E402
|
from vllm.distributed.parallel_state import ( # noqa E402
|
||||||
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
|
destroy_distributed_environment,
|
||||||
from safetensors.torch import load_file
|
destroy_model_parallel,
|
||||||
|
get_tp_group,
|
||||||
|
)
|
||||||
from vllm.utils.mem_constants import GiB_bytes
|
from vllm.utils.mem_constants import GiB_bytes
|
||||||
from vllm.utils.network_utils import get_open_port
|
from vllm.utils.network_utils import get_open_port
|
||||||
|
|
||||||
@@ -101,7 +104,6 @@ def load_and_merge_safetensors(directory):
|
|||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="External launcher Inference")
|
parser = argparse.ArgumentParser(description="External launcher Inference")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
@@ -109,60 +111,41 @@ def parse_args():
|
|||||||
default="Qwen/Qwen3-0.6B",
|
default="Qwen/Qwen3-0.6B",
|
||||||
help="Model name or path",
|
help="Model name or path",
|
||||||
)
|
)
|
||||||
parser.add_argument("--tp-size",
|
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
|
||||||
type=int,
|
parser.add_argument("--node-size", type=int, default=1, help="Total number of nodes")
|
||||||
default=1,
|
parser.add_argument("--node-rank", type=int, default=0, help="Rank of the current node")
|
||||||
help="Tensor parallel size")
|
parser.add_argument("--proc-per-node", type=int, default=1, help="Number of processes per node")
|
||||||
parser.add_argument("--node-size",
|
parser.add_argument("--master-addr", type=str, default="", help="Master node IP address")
|
||||||
type=int,
|
parser.add_argument("--master-port", type=int, default=0, help="Master node port")
|
||||||
default=1,
|
parser.add_argument("--enforce-eager", action="store_true", help="Enforce eager mode execution.")
|
||||||
help="Total number of nodes")
|
parser.add_argument("--trust-remote-code", action="store_true", help="Trust remote code.")
|
||||||
parser.add_argument("--node-rank",
|
parser.add_argument(
|
||||||
type=int,
|
"--enable-expert-parallel", action="store_true", help="Enable expert parallel, used in MOE models."
|
||||||
default=0,
|
)
|
||||||
help="Rank of the current node")
|
parser.add_argument("--enable-sleep-mode", action="store_true", help="Enable sleep mode for the engine.")
|
||||||
parser.add_argument("--proc-per-node",
|
parser.add_argument(
|
||||||
type=int,
|
"--temperature", type=float, default=0.8, help="Float that controls the randomness of the sampling."
|
||||||
default=1,
|
)
|
||||||
help="Number of processes per node")
|
parser.add_argument(
|
||||||
parser.add_argument("--master-addr",
|
"--model-weight-gib",
|
||||||
type=str,
|
type=float,
|
||||||
default="",
|
default=None,
|
||||||
help="Master node IP address")
|
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).",
|
||||||
parser.add_argument("--master-port",
|
)
|
||||||
type=int,
|
parser.add_argument(
|
||||||
default=0,
|
"--sleep-mode-level",
|
||||||
help="Master node port")
|
type=int,
|
||||||
parser.add_argument("--enforce-eager",
|
choices=[1, 2],
|
||||||
action="store_true",
|
default=1,
|
||||||
help="Enforce eager mode execution.")
|
help="Sleep mode level: 1 or 2. This example of level 2 is only supported for dense model.",
|
||||||
parser.add_argument("--trust-remote-code",
|
)
|
||||||
action="store_true",
|
|
||||||
help="Trust remote code.")
|
|
||||||
parser.add_argument("--enable-expert-parallel",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable expert parallel, used in MOE models.")
|
|
||||||
parser.add_argument("--enable-sleep-mode",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable sleep mode for the engine.")
|
|
||||||
parser.add_argument("--temperature",
|
|
||||||
type=float,
|
|
||||||
default=0.8,
|
|
||||||
help="Float that controls the randomness of the sampling.")
|
|
||||||
parser.add_argument("--model-weight-gib",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
|
|
||||||
parser.add_argument("--sleep-mode-level",
|
|
||||||
type=int,
|
|
||||||
choices=[1, 2],
|
|
||||||
default=1,
|
|
||||||
help="Sleep mode level: 1 or 2. This example of level 2 is only supported for dense model.")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.enable_sleep_mode:
|
if args.enable_sleep_mode:
|
||||||
if args.model_weight_gib is None or args.temperature != 0:
|
if args.model_weight_gib is None or args.temperature != 0:
|
||||||
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
|
parser.error(
|
||||||
|
"model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set."
|
||||||
|
)
|
||||||
if args.model_weight_gib <= 0:
|
if args.model_weight_gib <= 0:
|
||||||
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
|
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
|
||||||
if args.model == parser.get_default("model") and args.model_weight_gib is None:
|
if args.model == parser.get_default("model") and args.model_weight_gib is None:
|
||||||
@@ -220,7 +203,7 @@ def main(
|
|||||||
enable_sleep_mode=enable_sleep_mode,
|
enable_sleep_mode=enable_sleep_mode,
|
||||||
)
|
)
|
||||||
tp_ranks = get_tp_group().ranks
|
tp_ranks = get_tp_group().ranks
|
||||||
print(f'TP RANKS: {tp_ranks}')
|
print(f"TP RANKS: {tp_ranks}")
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
@@ -231,7 +214,7 @@ def main(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
free_bytes_after_sleep, total = torch.npu.mem_get_info()
|
free_bytes_after_sleep, total = torch.npu.mem_get_info()
|
||||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||||
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
|
print(f"Freed memory: {freed_bytes / 1024**3:.2f} GiB")
|
||||||
# now the freed memory should be larger than the model weights
|
# now the freed memory should be larger than the model weights
|
||||||
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
|
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
|
||||||
|
|
||||||
@@ -257,8 +240,7 @@ def main(
|
|||||||
break
|
break
|
||||||
prompt = output.prompt
|
prompt = output.prompt
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
print(f"Global rank: {rank}, Prompt: {prompt!r}, "
|
print(f"Global rank: {rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
f"Generated text: {generated_text!r}")
|
|
||||||
|
|
||||||
# Give engines time to pause their processing loops before exiting.
|
# Give engines time to pause their processing loops before exiting.
|
||||||
sleep(5)
|
sleep(5)
|
||||||
@@ -294,25 +276,26 @@ if __name__ == "__main__":
|
|||||||
world_size = node_size * proc_per_node
|
world_size = node_size * proc_per_node
|
||||||
|
|
||||||
procs = []
|
procs = []
|
||||||
for local_rank, rank in enumerate(
|
for local_rank, rank in enumerate(range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
|
||||||
range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
|
proc = Process(
|
||||||
proc = Process(target=main,
|
target=main,
|
||||||
args=(
|
args=(
|
||||||
local_rank,
|
local_rank,
|
||||||
rank,
|
rank,
|
||||||
master_addr,
|
master_addr,
|
||||||
master_port,
|
master_port,
|
||||||
args.model_weight_gib,
|
args.model_weight_gib,
|
||||||
args.model,
|
args.model,
|
||||||
world_size,
|
world_size,
|
||||||
tp_size,
|
tp_size,
|
||||||
args.enable_expert_parallel,
|
args.enable_expert_parallel,
|
||||||
args.enforce_eager,
|
args.enforce_eager,
|
||||||
args.trust_remote_code,
|
args.trust_remote_code,
|
||||||
args.enable_sleep_mode,
|
args.enable_sleep_mode,
|
||||||
args.temperature,
|
args.temperature,
|
||||||
args.sleep_mode_level,
|
args.sleep_mode_level,
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
proc.start()
|
proc.start()
|
||||||
procs.append(proc)
|
procs.append(proc)
|
||||||
@@ -320,9 +303,7 @@ if __name__ == "__main__":
|
|||||||
for proc in procs:
|
for proc in procs:
|
||||||
proc.join(timeout=600)
|
proc.join(timeout=600)
|
||||||
if proc.exitcode is None:
|
if proc.exitcode is None:
|
||||||
print(
|
print(f"Killing process {proc.pid} that didn't stop within 30 minutes.")
|
||||||
f"Killing process {proc.pid} that didn't stop within 30 minutes."
|
|
||||||
)
|
|
||||||
proc.kill()
|
proc.kill()
|
||||||
exit_code = 1
|
exit_code = 1
|
||||||
elif proc.exitcode:
|
elif proc.exitcode:
|
||||||
|
|||||||
@@ -24,12 +24,13 @@ For most models, the prompt format should follow corresponding examples
|
|||||||
on HuggingFace model repository.
|
on HuggingFace model repository.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
from vllm.assets.audio import AudioAsset
|
from vllm.assets.audio import AudioAsset
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import librosa # type: ignore
|
import librosa # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise Exception("Can't import librosa, please ensure it's installed")
|
raise Exception("Can't import librosa, please ensure it's installed")
|
||||||
|
|
||||||
@@ -40,7 +41,7 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
|||||||
|
|
||||||
|
|
||||||
def prepare_inputs(audio_count: int, audio_path1: str, audio_path2: str):
|
def prepare_inputs(audio_count: int, audio_path1: str, audio_path2: str):
|
||||||
use_vllm_audio_assert = True if audio_path1 == "mary_had_lamb" and audio_path2 == "winning_call" else False
|
use_vllm_audio_assert = audio_path1 == "mary_had_lamb" and audio_path2 == "winning_call"
|
||||||
if use_vllm_audio_assert:
|
if use_vllm_audio_assert:
|
||||||
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
|
||||||
else:
|
else:
|
||||||
@@ -48,22 +49,22 @@ def prepare_inputs(audio_count: int, audio_path1: str, audio_path2: str):
|
|||||||
|
|
||||||
question_per_audio_count = {
|
question_per_audio_count = {
|
||||||
1: "What is recited in the audio?",
|
1: "What is recited in the audio?",
|
||||||
2: "What sport and what nursery rhyme are referenced?"
|
2: "What sport and what nursery rhyme are referenced?",
|
||||||
}
|
}
|
||||||
|
|
||||||
audio_in_prompt = "".join([
|
audio_in_prompt = "".join([f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)])
|
||||||
f"Audio {idx+1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
|
|
||||||
for idx in range(audio_count)
|
|
||||||
])
|
|
||||||
question = question_per_audio_count[audio_count]
|
question = question_per_audio_count[audio_count]
|
||||||
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
prompt = (
|
||||||
"<|im_start|>user\n"
|
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
|
||||||
f"{audio_in_prompt}{question}<|im_end|>\n"
|
"<|im_start|>user\n"
|
||||||
"<|im_start|>assistant\n")
|
f"{audio_in_prompt}{question}<|im_end|>\n"
|
||||||
|
"<|im_start|>assistant\n"
|
||||||
|
)
|
||||||
|
|
||||||
mm_data = {
|
mm_data = {
|
||||||
"audio":
|
"audio": audio_assets
|
||||||
audio_assets if not use_vllm_audio_assert else [asset.audio_and_sample_rate for asset in audio_assets[:audio_count]]
|
if not use_vllm_audio_assert
|
||||||
|
else [asset.audio_and_sample_rate for asset in audio_assets[:audio_count]]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Merge text prompt and audio data into inputs
|
# Merge text prompt and audio data into inputs
|
||||||
@@ -76,17 +77,17 @@ def main(audio_count: int, audio_path1: str, audio_path2: str):
|
|||||||
# lower-end GPUs.
|
# lower-end GPUs.
|
||||||
# Unless specified, these settings have been tested to work on a single L4.
|
# Unless specified, these settings have been tested to work on a single L4.
|
||||||
# `limit_mm_per_prompt`: the max num items for each modality per prompt.
|
# `limit_mm_per_prompt`: the max num items for each modality per prompt.
|
||||||
llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct",
|
llm = LLM(
|
||||||
max_model_len=4096,
|
model="Qwen/Qwen2-Audio-7B-Instruct",
|
||||||
max_num_seqs=5,
|
max_model_len=4096,
|
||||||
limit_mm_per_prompt={"audio": audio_count},
|
max_num_seqs=5,
|
||||||
enforce_eager=True)
|
limit_mm_per_prompt={"audio": audio_count},
|
||||||
|
enforce_eager=True,
|
||||||
|
)
|
||||||
|
|
||||||
inputs = prepare_inputs(audio_count, audio_path1, audio_path2)
|
inputs = prepare_inputs(audio_count, audio_path1, audio_path2)
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.2,
|
sampling_params = SamplingParams(temperature=0.2, max_tokens=64, stop_token_ids=None)
|
||||||
max_tokens=64,
|
|
||||||
stop_token_ids=None)
|
|
||||||
|
|
||||||
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
outputs = llm.generate(inputs, sampling_params=sampling_params)
|
||||||
|
|
||||||
@@ -96,7 +97,9 @@ def main(audio_count: int, audio_path1: str, audio_path2: str):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Arguments of rank table generator", )
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Arguments of rank table generator",
|
||||||
|
)
|
||||||
parser.add_argument("--audio-path1", type=str, default="mary_had_lamb")
|
parser.add_argument("--audio-path1", type=str, default="mary_had_lamb")
|
||||||
parser.add_argument("--audio-path2", type=str, default="winning_call")
|
parser.add_argument("--audio-path2", type=str, default="winning_call")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import argparse
|
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
@@ -11,14 +11,14 @@ os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
parser.add_argument('--input_len', type=int, default=1024)
|
parser.add_argument("--input_len", type=int, default=1024)
|
||||||
parser.add_argument('--output_len', type=int, default=128)
|
parser.add_argument("--output_len", type=int, default=128)
|
||||||
parser.add_argument('--bs', type=int, default=1)
|
parser.add_argument("--bs", type=int, default=1)
|
||||||
parser.add_argument('--model_path', type=str, default="deepseek-ai/DeepSeek-V2-Lite")
|
parser.add_argument("--model_path", type=str, default="deepseek-ai/DeepSeek-V2-Lite")
|
||||||
parser.add_argument('--tp', type=int, default=2)
|
parser.add_argument("--tp", type=int, default=2)
|
||||||
parser.add_argument('--pcp', type=int, default=2)
|
parser.add_argument("--pcp", type=int, default=2)
|
||||||
parser.add_argument('--dcp', type=int, default=1)
|
parser.add_argument("--dcp", type=int, default=1)
|
||||||
parser.add_argument('--iter_times', type=int, default=1)
|
parser.add_argument("--iter_times", type=int, default=1)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -26,10 +26,10 @@ if __name__ == "__main__":
|
|||||||
"The capital of France is",
|
"The capital of France is",
|
||||||
"Hello, my name is Tom, I am",
|
"Hello, my name is Tom, I am",
|
||||||
"The president of United States is",
|
"The president of United States is",
|
||||||
"AI future is"
|
"AI future is",
|
||||||
]
|
]
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens=args.output_len)
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=args.output_len)
|
||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=args.model_path,
|
model=args.model_path,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
@@ -44,7 +44,7 @@ if __name__ == "__main__":
|
|||||||
max_model_len=1024,
|
max_model_len=1024,
|
||||||
max_num_seqs=1,
|
max_num_seqs=1,
|
||||||
block_size=128,
|
block_size=128,
|
||||||
gpu_memory_utilization=0.9
|
gpu_memory_utilization=0.9,
|
||||||
)
|
)
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|||||||
@@ -37,11 +37,13 @@ def main():
|
|||||||
# Create a sampling params object.
|
# Create a sampling params object.
|
||||||
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
|
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
|
||||||
# Create an LLM.
|
# Create an LLM.
|
||||||
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
|
llm = LLM(
|
||||||
tensor_parallel_size=2,
|
model="deepseek-ai/DeepSeek-V2-Lite",
|
||||||
enforce_eager=True,
|
tensor_parallel_size=2,
|
||||||
trust_remote_code=True,
|
enforce_eager=True,
|
||||||
max_model_len=1024)
|
trust_remote_code=True,
|
||||||
|
max_model_len=1024,
|
||||||
|
)
|
||||||
|
|
||||||
# Generate texts from the prompts.
|
# Generate texts from the prompts.
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|||||||
@@ -25,11 +25,12 @@ from vllm.utils.mem_constants import GiB_bytes
|
|||||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
prompt = "How are you?"
|
prompt = "How are you?"
|
||||||
|
|
||||||
free, total = torch.npu.mem_get_info()
|
free, total = torch.npu.mem_get_info()
|
||||||
print(f"Free memory before sleep: {free / 1024 ** 3:.2f} GiB")
|
print(f"Free memory before sleep: {free / 1024**3:.2f} GiB")
|
||||||
# record npu memory use baseline in case other process is running
|
# record npu memory use baseline in case other process is running
|
||||||
used_bytes_baseline = total - free
|
used_bytes_baseline = total - free
|
||||||
llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
|
llm = LLM("Qwen/Qwen2.5-0.5B-Instruct", enable_sleep_mode=True)
|
||||||
@@ -39,9 +40,7 @@ def main():
|
|||||||
llm.sleep(level=1)
|
llm.sleep(level=1)
|
||||||
|
|
||||||
free_npu_bytes_after_sleep, total = torch.npu.mem_get_info()
|
free_npu_bytes_after_sleep, total = torch.npu.mem_get_info()
|
||||||
print(
|
print(f"Free memory after sleep: {free_npu_bytes_after_sleep / 1024**3:.2f} GiB")
|
||||||
f"Free memory after sleep: {free_npu_bytes_after_sleep / 1024 ** 3:.2f} GiB"
|
|
||||||
)
|
|
||||||
used_bytes = total - free_npu_bytes_after_sleep - used_bytes_baseline
|
used_bytes = total - free_npu_bytes_after_sleep - used_bytes_baseline
|
||||||
# now the memory usage should be less than the model weights
|
# now the memory usage should be less than the model weights
|
||||||
# (0.5B model, 1GiB weights)
|
# (0.5B model, 1GiB weights)
|
||||||
|
|||||||
@@ -63,19 +63,21 @@ from multiprocessing import Process
|
|||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.distributed.parallel_state import ( # noqa E402
|
from vllm.distributed.parallel_state import ( # noqa E402
|
||||||
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
|
destroy_distributed_environment,
|
||||||
from safetensors.torch import load_file
|
destroy_model_parallel,
|
||||||
|
get_tp_group,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
||||||
from vllm.utils.mem_constants import GiB_bytes
|
from vllm.utils.mem_constants import GiB_bytes
|
||||||
from vllm.utils.network_utils import get_open_port
|
from vllm.utils.network_utils import get_open_port
|
||||||
|
|
||||||
from vllm.model_executor.model_loader.utils import \
|
|
||||||
process_weights_after_loading
|
|
||||||
|
|
||||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
|
|
||||||
|
|
||||||
def patch_vllm_moe_model_weight_loader(model):
|
def patch_vllm_moe_model_weight_loader(model):
|
||||||
# Define MLP attribute mapping for different model types
|
# Define MLP attribute mapping for different model types
|
||||||
|
|
||||||
@@ -92,6 +94,7 @@ def patch_vllm_moe_model_weight_loader(model):
|
|||||||
if "w13_weight" in name or "w2_weight" in name:
|
if "w13_weight" in name or "w2_weight" in name:
|
||||||
param.weight_loader = mlp.experts.weight_loader
|
param.weight_loader = mlp.experts.weight_loader
|
||||||
|
|
||||||
|
|
||||||
def load_and_merge_safetensors(directory):
|
def load_and_merge_safetensors(directory):
|
||||||
merged_dict = {}
|
merged_dict = {}
|
||||||
|
|
||||||
@@ -99,7 +102,7 @@ def load_and_merge_safetensors(directory):
|
|||||||
raise ValueError(f"directory is not exist : {directory}")
|
raise ValueError(f"directory is not exist : {directory}")
|
||||||
|
|
||||||
for filename in os.listdir(directory):
|
for filename in os.listdir(directory):
|
||||||
if filename.endswith('.safetensors'):
|
if filename.endswith(".safetensors"):
|
||||||
file_path = os.path.join(directory, filename)
|
file_path = os.path.join(directory, filename)
|
||||||
print(f"loading file: {file_path}")
|
print(f"loading file: {file_path}")
|
||||||
|
|
||||||
@@ -108,8 +111,8 @@ def load_and_merge_safetensors(directory):
|
|||||||
|
|
||||||
return merged_dict
|
return merged_dict
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="External launcher Inference")
|
parser = argparse.ArgumentParser(description="External launcher Inference")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
@@ -117,55 +120,34 @@ def parse_args():
|
|||||||
default="Qwen/Qwen3-0.6B",
|
default="Qwen/Qwen3-0.6B",
|
||||||
help="Model name or path",
|
help="Model name or path",
|
||||||
)
|
)
|
||||||
parser.add_argument("--tp-size",
|
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
|
||||||
type=int,
|
parser.add_argument("--node-size", type=int, default=1, help="Total number of nodes")
|
||||||
default=1,
|
parser.add_argument("--node-rank", type=int, default=0, help="Rank of the current node")
|
||||||
help="Tensor parallel size")
|
parser.add_argument("--proc-per-node", type=int, default=1, help="Number of processes per node")
|
||||||
parser.add_argument("--node-size",
|
parser.add_argument("--master-addr", type=str, default="", help="Master node IP address")
|
||||||
type=int,
|
parser.add_argument("--master-port", type=int, default=0, help="Master node port")
|
||||||
default=1,
|
parser.add_argument("--enforce-eager", action="store_true", help="Enforce eager mode execution.")
|
||||||
help="Total number of nodes")
|
parser.add_argument("--trust-remote-code", action="store_true", help="Trust remote code.")
|
||||||
parser.add_argument("--node-rank",
|
parser.add_argument(
|
||||||
type=int,
|
"--enable-expert-parallel", action="store_true", help="Enable expert parallel, used in MOE models."
|
||||||
default=0,
|
)
|
||||||
help="Rank of the current node")
|
parser.add_argument("--enable-sleep-mode", action="store_true", help="Enable sleep mode for the engine.")
|
||||||
parser.add_argument("--proc-per-node",
|
parser.add_argument(
|
||||||
type=int,
|
"--temperature", type=float, default=0.8, help="Float that controls the randomness of the sampling."
|
||||||
default=1,
|
)
|
||||||
help="Number of processes per node")
|
parser.add_argument(
|
||||||
parser.add_argument("--master-addr",
|
"--model-weight-gib",
|
||||||
type=str,
|
type=float,
|
||||||
default="",
|
default=None,
|
||||||
help="Master node IP address")
|
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).",
|
||||||
parser.add_argument("--master-port",
|
)
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Master node port")
|
|
||||||
parser.add_argument("--enforce-eager",
|
|
||||||
action="store_true",
|
|
||||||
help="Enforce eager mode execution.")
|
|
||||||
parser.add_argument("--trust-remote-code",
|
|
||||||
action="store_true",
|
|
||||||
help="Trust remote code.")
|
|
||||||
parser.add_argument("--enable-expert-parallel",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable expert parallel, used in MOE models.")
|
|
||||||
parser.add_argument("--enable-sleep-mode",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable sleep mode for the engine.")
|
|
||||||
parser.add_argument("--temperature",
|
|
||||||
type=float,
|
|
||||||
default=0.8,
|
|
||||||
help="Float that controls the randomness of the sampling.")
|
|
||||||
parser.add_argument("--model-weight-gib",
|
|
||||||
type=float,
|
|
||||||
default=None,
|
|
||||||
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.enable_sleep_mode:
|
if args.enable_sleep_mode:
|
||||||
if args.model_weight_gib is None or args.temperature != 0:
|
if args.model_weight_gib is None or args.temperature != 0:
|
||||||
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
|
parser.error(
|
||||||
|
"model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set."
|
||||||
|
)
|
||||||
if args.model_weight_gib <= 0:
|
if args.model_weight_gib <= 0:
|
||||||
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
|
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
|
||||||
if args.model == parser.get_default("model") and args.model_weight_gib is None:
|
if args.model == parser.get_default("model") and args.model_weight_gib is None:
|
||||||
@@ -219,7 +201,7 @@ def main(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
distributed_executor_backend="external_launcher",
|
distributed_executor_backend="external_launcher",
|
||||||
seed=0,
|
seed=0,
|
||||||
gpu_memory_utilization = 0.95,
|
gpu_memory_utilization=0.95,
|
||||||
enable_sleep_mode=enable_sleep_mode,
|
enable_sleep_mode=enable_sleep_mode,
|
||||||
)
|
)
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
@@ -231,7 +213,7 @@ def main(
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
free_bytes_after_sleep, total = torch.npu.mem_get_info()
|
free_bytes_after_sleep, total = torch.npu.mem_get_info()
|
||||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||||
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
|
print(f"Freed memory: {freed_bytes / 1024**3:.2f} GiB")
|
||||||
# now the freed memory should be larger than the model weights
|
# now the freed memory should be larger than the model weights
|
||||||
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
|
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
|
||||||
|
|
||||||
@@ -242,9 +224,9 @@ def main(
|
|||||||
patch_vllm_moe_model_weight_loader(runmodel)
|
patch_vllm_moe_model_weight_loader(runmodel)
|
||||||
sd = load_and_merge_safetensors(model_path)
|
sd = load_and_merge_safetensors(model_path)
|
||||||
runmodel.load_weights(sd.items())
|
runmodel.load_weights(sd.items())
|
||||||
print('load state dict done')
|
print("load state dict done")
|
||||||
tp_ranks = get_tp_group().ranks
|
tp_ranks = get_tp_group().ranks
|
||||||
print(f'TP RANKS: {tp_ranks}')
|
print(f"TP RANKS: {tp_ranks}")
|
||||||
|
|
||||||
vllm_config = llm.llm_engine.vllm_config.model_config
|
vllm_config = llm.llm_engine.vllm_config.model_config
|
||||||
device = next(runmodel.parameters()).device
|
device = next(runmodel.parameters()).device
|
||||||
@@ -262,8 +244,7 @@ def main(
|
|||||||
break
|
break
|
||||||
prompt = output.prompt
|
prompt = output.prompt
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
print(f"Global rank: {rank}, Prompt: {prompt!r}, "
|
print(f"Global rank: {rank}, Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
f"Generated text: {generated_text!r}")
|
|
||||||
|
|
||||||
# Give engines time to pause their processing loops before exiting.
|
# Give engines time to pause their processing loops before exiting.
|
||||||
sleep(5)
|
sleep(5)
|
||||||
@@ -299,24 +280,25 @@ if __name__ == "__main__":
|
|||||||
world_size = node_size * proc_per_node
|
world_size = node_size * proc_per_node
|
||||||
|
|
||||||
procs = []
|
procs = []
|
||||||
for local_rank, rank in enumerate(
|
for local_rank, rank in enumerate(range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
|
||||||
range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
|
proc = Process(
|
||||||
proc = Process(target=main,
|
target=main,
|
||||||
args=(
|
args=(
|
||||||
local_rank,
|
local_rank,
|
||||||
rank,
|
rank,
|
||||||
master_addr,
|
master_addr,
|
||||||
master_port,
|
master_port,
|
||||||
args.model_weight_gib,
|
args.model_weight_gib,
|
||||||
args.model,
|
args.model,
|
||||||
world_size,
|
world_size,
|
||||||
tp_size,
|
tp_size,
|
||||||
args.enable_expert_parallel,
|
args.enable_expert_parallel,
|
||||||
args.enforce_eager,
|
args.enforce_eager,
|
||||||
args.trust_remote_code,
|
args.trust_remote_code,
|
||||||
args.enable_sleep_mode,
|
args.enable_sleep_mode,
|
||||||
args.temperature,
|
args.temperature,
|
||||||
))
|
),
|
||||||
|
)
|
||||||
|
|
||||||
proc.start()
|
proc.start()
|
||||||
procs.append(proc)
|
procs.append(proc)
|
||||||
@@ -324,9 +306,7 @@ if __name__ == "__main__":
|
|||||||
for proc in procs:
|
for proc in procs:
|
||||||
proc.join(timeout=600)
|
proc.join(timeout=600)
|
||||||
if proc.exitcode is None:
|
if proc.exitcode is None:
|
||||||
print(
|
print(f"Killing process {proc.pid} that didn't stop within 30 minutes.")
|
||||||
f"Killing process {proc.pid} that didn't stop within 30 minutes."
|
|
||||||
)
|
|
||||||
proc.kill()
|
proc.kill()
|
||||||
exit_code = 1
|
exit_code = 1
|
||||||
elif proc.exitcode:
|
elif proc.exitcode:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ Run:
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
|
|
||||||
|
|
||||||
@@ -37,16 +36,12 @@ def get_prompt_embeds(
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
embedding_layer: torch.nn.Module,
|
embedding_layer: torch.nn.Module,
|
||||||
):
|
):
|
||||||
token_ids = tokenizer.apply_chat_template(
|
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
|
||||||
chat, add_generation_prompt=True, return_tensors="pt"
|
|
||||||
)
|
|
||||||
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
def single_prompt_inference(
|
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
|
||||||
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
|
|
||||||
):
|
|
||||||
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
|
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
|
||||||
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
|
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
|
||||||
|
|
||||||
@@ -63,18 +58,14 @@ def single_prompt_inference(
|
|||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
|
|
||||||
|
|
||||||
def batch_prompt_inference(
|
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
|
||||||
llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
|
|
||||||
):
|
|
||||||
chats = [
|
chats = [
|
||||||
[{"role": "user", "content": "Please tell me about the capital of France."}],
|
[{"role": "user", "content": "Please tell me about the capital of France."}],
|
||||||
[{"role": "user", "content": "When is the day longest during the year?"}],
|
[{"role": "user", "content": "When is the day longest during the year?"}],
|
||||||
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
|
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
|
||||||
]
|
]
|
||||||
|
|
||||||
prompt_embeds_list = [
|
prompt_embeds_list = [get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats]
|
||||||
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
|
|
||||||
]
|
|
||||||
|
|
||||||
outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
|
outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
||||||
PreTrainedTokenizer)
|
|
||||||
from vllm import LLM
|
from vllm import LLM
|
||||||
|
|
||||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||||
@@ -17,27 +16,21 @@ def init_tokenizer_and_llm(model_name: str):
|
|||||||
return tokenizer, embedding_layer, llm
|
return tokenizer, embedding_layer, llm
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_embeds(chat: list[dict[str,
|
def get_prompt_embeds(chat: list[dict[str, str]], tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
|
||||||
str]], tokenizer: PreTrainedTokenizer,
|
token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")
|
||||||
embedding_layer: torch.nn.Module):
|
|
||||||
token_ids = tokenizer.apply_chat_template(chat,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
return_tensors='pt')
|
|
||||||
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
|
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
|
||||||
embedding_layer: torch.nn.Module):
|
chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
|
||||||
chat = [{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Please tell me about the capital of France."
|
|
||||||
}]
|
|
||||||
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
|
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
|
||||||
|
|
||||||
outputs = llm.generate({
|
outputs = llm.generate(
|
||||||
"prompt_embeds": prompt_embeds,
|
{
|
||||||
})
|
"prompt_embeds": prompt_embeds,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
print("\n[Single Inference Output]")
|
print("\n[Single Inference Output]")
|
||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
@@ -46,34 +39,22 @@ def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
|
|||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
|
|
||||||
|
|
||||||
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
|
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module):
|
||||||
embedding_layer: torch.nn.Module):
|
chats = [
|
||||||
chats = [[{
|
[{"role": "user", "content": "Please tell me about the capital of France."}],
|
||||||
"role": "user",
|
[{"role": "user", "content": "When is the day longest during the year?"}],
|
||||||
"content": "Please tell me about the capital of France."
|
[{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
|
||||||
}],
|
|
||||||
[{
|
|
||||||
"role": "user",
|
|
||||||
"content": "When is the day longest during the year?"
|
|
||||||
}],
|
|
||||||
[{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Where is bigger, the moon or the sun?"
|
|
||||||
}]]
|
|
||||||
|
|
||||||
prompt_embeds_list = [
|
|
||||||
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
|
|
||||||
]
|
]
|
||||||
|
|
||||||
outputs = llm.generate([{
|
prompt_embeds_list = [get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats]
|
||||||
"prompt_embeds": embeds
|
|
||||||
} for embeds in prompt_embeds_list])
|
outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
|
||||||
|
|
||||||
print("\n[Batch Inference Outputs]")
|
print("\n[Batch Inference Outputs]")
|
||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
for i, o in enumerate(outputs):
|
for i, o in enumerate(outputs):
|
||||||
print(f"Q{i+1}: {chats[i][0]['content']}")
|
print(f"Q{i + 1}: {chats[i][0]['content']}")
|
||||||
print(f"A{i+1}: {o.outputs[0].text}\n")
|
print(f"A{i + 1}: {o.outputs[0].text}\n")
|
||||||
print("-" * 30)
|
print("-" * 30)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +1,25 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme, QuantizationStrategy, QuantizationType
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from transformers import AutoModelForCausalLM, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, \
|
|
||||||
AutoTokenizer, AutoProcessor, AutoConfig, AutoImageProcessor
|
|
||||||
|
|
||||||
from llmcompressor import oneshot
|
from llmcompressor import oneshot
|
||||||
from llmcompressor.modifiers.awq import AWQModifier
|
from llmcompressor.modifiers.awq import AWQModifier
|
||||||
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
|
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
|
||||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme, QuantizationType, QuantizationStrategy
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
W8A8_W_cha_A_ten_static_symmetric = {
|
W8A8_W_cha_A_ten_static_symmetric = {
|
||||||
"group_0": QuantizationScheme(
|
"group_0": QuantizationScheme(
|
||||||
targets=["Linear"],
|
targets=["Linear"],
|
||||||
weights=QuantizationArgs(
|
weights=QuantizationArgs(
|
||||||
num_bits=8,
|
num_bits=8, type=QuantizationType.INT, strategy=QuantizationStrategy.CHANNEL, symmetric=True, dynamic=False
|
||||||
type=QuantizationType.INT,
|
|
||||||
strategy=QuantizationStrategy.CHANNEL,
|
|
||||||
symmetric=True,
|
|
||||||
dynamic=False
|
|
||||||
),
|
),
|
||||||
input_activations=QuantizationArgs(
|
input_activations=QuantizationArgs(
|
||||||
num_bits=8,
|
num_bits=8, type=QuantizationType.INT, strategy=QuantizationStrategy.TENSOR, symmetric=True, dynamic=False
|
||||||
type=QuantizationType.INT,
|
|
||||||
strategy=QuantizationStrategy.TENSOR,
|
|
||||||
symmetric=True,
|
|
||||||
dynamic=False
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
@@ -53,19 +47,19 @@ TOKENIZER_DICT = {
|
|||||||
|
|
||||||
def load_environment_variables():
|
def load_environment_variables():
|
||||||
env_vars = {
|
env_vars = {
|
||||||
'model_path': "Qwen/Qwen3-32B",
|
"model_path": "Qwen/Qwen3-32B",
|
||||||
'export_path': "/llm-compressor/export/GPTQ/W8A8_W_cha_A_ten_static_symmetric",
|
"export_path": "/llm-compressor/export/GPTQ/W8A8_W_cha_A_ten_static_symmetric",
|
||||||
'modifier': "GPTQ",
|
"modifier": "GPTQ",
|
||||||
'schemes': "W8A8_W_cha_A_ten_static_symmetric",
|
"schemes": "W8A8_W_cha_A_ten_static_symmetric",
|
||||||
'calib_prompt_path': "HuggingFaceH4/ultrachat_200k"
|
"calib_prompt_path": "HuggingFaceH4/ultrachat_200k",
|
||||||
}
|
}
|
||||||
|
|
||||||
# verify export model path
|
# verify export model path
|
||||||
if env_vars['export_path'] is None:
|
if env_vars["export_path"] is None:
|
||||||
env_vars['export_path'] = env_vars['model_path'].rstrip("/") + "-" + env_vars['modifier']
|
env_vars["export_path"] = env_vars["model_path"].rstrip("/") + "-" + env_vars["modifier"]
|
||||||
if env_vars['schemes'] is not None:
|
if env_vars["schemes"] is not None:
|
||||||
env_vars['export_path'] += "-" + env_vars['schemes']
|
env_vars["export_path"] += "-" + env_vars["schemes"]
|
||||||
os.makedirs(env_vars['export_path'], exist_ok=True)
|
os.makedirs(env_vars["export_path"], exist_ok=True)
|
||||||
|
|
||||||
return env_vars
|
return env_vars
|
||||||
|
|
||||||
@@ -74,19 +68,17 @@ def load_calibration_text_dataset(calib_prompt_path, tokenizer):
|
|||||||
# Load dataset
|
# Load dataset
|
||||||
for f in os.listdir(calib_prompt_path):
|
for f in os.listdir(calib_prompt_path):
|
||||||
print(f)
|
print(f)
|
||||||
if any(f.lower().endswith('.jsonl') for f in os.listdir(calib_prompt_path)):
|
if any(f.lower().endswith(".jsonl") for f in os.listdir(calib_prompt_path)):
|
||||||
ds = load_dataset('json', data_dir=calib_prompt_path, split='validation')
|
ds = load_dataset("json", data_dir=calib_prompt_path, split="validation")
|
||||||
elif any(f.lower().endswith('.parquet') for f in os.listdir(calib_prompt_path)):
|
elif any(f.lower().endswith(".parquet") for f in os.listdir(calib_prompt_path)):
|
||||||
ds = load_dataset("parquet", data_dir=calib_prompt_path, split="train[:512]")
|
ds = load_dataset("parquet", data_dir=calib_prompt_path, split="train[:512]")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported calibration file format: {}".format(
|
raise ValueError("Unsupported calibration file format: {}".format(calib_prompt_path.split(".")[-1]))
|
||||||
calib_prompt_path.split('.')[-1]))
|
|
||||||
|
|
||||||
# Preprocess dataset
|
# Preprocess dataset
|
||||||
def preprocess(example):
|
def preprocess(example):
|
||||||
if tokenizer.chat_template is not None:
|
if tokenizer.chat_template is not None:
|
||||||
return {"text": tokenizer.apply_chat_template(
|
return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}
|
||||||
example["messages"], tokenize=False)}
|
|
||||||
else:
|
else:
|
||||||
return {"text": example["messages"]}
|
return {"text": example["messages"]}
|
||||||
|
|
||||||
@@ -118,8 +110,8 @@ def quantize_model(model, env_vars, dataset_dict=None):
|
|||||||
|
|
||||||
# define a llmcompressor recipe
|
# define a llmcompressor recipe
|
||||||
recipe = [
|
recipe = [
|
||||||
MODIFIER_DICT[env_vars['modifier']](
|
MODIFIER_DICT[env_vars["modifier"]](
|
||||||
config_groups=SCHEMES_DICT[env_vars['schemes']],
|
config_groups=SCHEMES_DICT[env_vars["schemes"]],
|
||||||
ignore=ignore,
|
ignore=ignore,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -138,18 +130,16 @@ def save_quantized_model(model, tokenizer, save_path, save_compressed=False):
|
|||||||
tokenizer.save_pretrained(save_path)
|
tokenizer.save_pretrained(save_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
# get environment variables
|
# get environment variables
|
||||||
env_vars = load_environment_variables()
|
env_vars = load_environment_variables()
|
||||||
|
|
||||||
# support model type list
|
# support model type list
|
||||||
config = AutoConfig.from_pretrained(env_vars['model_path'], trust_remote_code=True)
|
config = AutoConfig.from_pretrained(env_vars["model_path"], trust_remote_code=True)
|
||||||
model_type = config.model_type
|
model_type = config.model_type
|
||||||
|
|
||||||
model = MODEL_DICT[model_type].from_pretrained(
|
model = MODEL_DICT[model_type].from_pretrained(env_vars["model_path"], torch_dtype="auto", trust_remote_code=True)
|
||||||
env_vars['model_path'], torch_dtype="auto", trust_remote_code=True
|
tokenizer = TOKENIZER_DICT[model_type].from_pretrained(env_vars["model_path"], trust_remote_code=True)
|
||||||
)
|
|
||||||
tokenizer = TOKENIZER_DICT[model_type].from_pretrained(env_vars['model_path'], trust_remote_code=True)
|
|
||||||
|
|
||||||
ds = load_calibration_text_dataset(env_vars["calib_prompt_path"], tokenizer)
|
ds = load_calibration_text_dataset(env_vars["calib_prompt_path"], tokenizer)
|
||||||
|
|
||||||
@@ -157,4 +147,4 @@ if __name__ == '__main__':
|
|||||||
quantize_model(model, env_vars, ds)
|
quantize_model(model, env_vars, ds)
|
||||||
|
|
||||||
# save the quantized model
|
# save the quantized model
|
||||||
save_quantized_model(model, tokenizer, env_vars['export_path'], True)
|
save_quantized_model(model, tokenizer, env_vars["export_path"], True)
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
||||||
|
|
||||||
from llmcompressor import oneshot
|
from llmcompressor import oneshot
|
||||||
from llmcompressor.modifiers.quantization import GPTQModifier
|
from llmcompressor.modifiers.quantization import GPTQModifier
|
||||||
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier
|
||||||
from llmcompressor.utils import dispatch_for_generation
|
from llmcompressor.utils import dispatch_for_generation
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
# Select model and load it.
|
# Select model and load it.
|
||||||
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
|
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ plugins.md029.enabled = false # ol-prefix
|
|||||||
line-length = 120
|
line-length = 120
|
||||||
# Folder to be modified
|
# Folder to be modified
|
||||||
exclude = [
|
exclude = [
|
||||||
"examples/**",
|
|
||||||
"tests/**",
|
"tests/**",
|
||||||
"vllm_ascend/**",
|
"vllm_ascend/**",
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user