[Lint]Style: Convert example to ruff format (#5863)

### What this PR does / why we need it?
This PR fixes linting issues in the `example/` directory to align with the
project's Ruff configuration.

- vLLM version: v0.13.0
- vLLM main:
bde38c11df

Signed-off-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Co-authored-by: root <root@LAPTOP-VQKDDVMG.localdomain>
This commit is contained in:
Authored by SILONG ZENG on 2026-01-13 20:46:50 +08:00,
committed by GitHub
parent f7b904641e
commit 78d5ce3e01
23 changed files with 678 additions and 1037 deletions

View File

@@ -125,7 +125,7 @@ import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List, Tuple, Dict
from typing import Any
import httpx
from fastapi import FastAPI, Request
@@ -150,22 +150,21 @@ class InstanceType:
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.url = f"http://{host}:{port}/v1"
try:
ip = ipaddress.ip_address(self.host)
if isinstance(ip, ipaddress.IPv6Address):
self.url = f'http://[{host}]:{port}/v1'
self.url = f"http://[{host}]:{port}/v1"
except Exception:
pass
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.client = httpx.AsyncClient(
timeout=None,
base_url=self.url,
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
)
self.active_tokens = 0
self.active_kv_cache = 0 # Only for prefiller
self.active_requests = 0 # Number of active requests
@@ -186,16 +185,11 @@ class ServerState:
class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.node_listener = NodeListener(self)
self.prefillers: List[ServerState] = [
ServerState(h, p) for h, p in prefiller_instances
]
self.decoders: List[ServerState] = [
ServerState(h, p) for h, p in decoder_instances
]
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
self.req_to_prefiller = {}
self.req_id_lock = asyncio.Lock()
# Removed selection locks - no longer needed for synchronous methods
@@ -203,10 +197,8 @@ class ProxyState:
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.prefiller_heap = [(0, i, server)
for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server)
for i, server in enumerate(self.decoders)]
self.prefiller_heap = [(0, i, server) for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server) for i, server in enumerate(self.decoders)]
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
@@ -216,23 +208,18 @@ class ProxyState:
# Priority based on active_tokens and active_kv_cache
priority = server.active_tokens + server.active_kv_cache * 0.3
# Remove old entry and add new one
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
if i != server_idx]
heapq.heappush(self.prefiller_heap,
(priority, server_idx, server)) # type: ignore
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
heapq.heappush(self.prefiller_heap, (priority, server_idx, server)) # type: ignore
def _update_decoder_priority(self, server_idx: int):
"""Update the priority of a decoder server in the heap."""
server = self.decoders[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
if i != server_idx]
heapq.heappush(self.decoder_heap,
(priority, server_idx, server)) # type: ignore
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
heapq.heappush(self.decoder_heap, (priority, server_idx, server)) # type: ignore
def abort_prefiller_request(self, server_idx: int,
request_id): # Changed to synchronous
def abort_prefiller_request(self, server_idx: int, request_id): # Changed to synchronous
"""
Mark a request as aborted. This will helps to release kv cache in
prefiller node.
@@ -240,8 +227,7 @@ class ProxyState:
# No lock needed - atomic operation
self.prefillers[server_idx].aborted_requests.add(request_id)
def aquire_aborted_prefiller_requests(
self, server_idx: int): # Changed to synchronous
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
"""
Get the set of aborted requests and clear it.
This is used to release kv cache in prefiller node.
@@ -314,9 +300,7 @@ class ProxyState:
def calculate_decode_scores(self, request_length: int) -> float:
return request_length
async def add_instances(
self, instance_type: str, instances: List[ServerState]
) -> Tuple[List[str], List[str]]:
async def add_instances(self, instance_type: str, instances: list[ServerState]) -> tuple[list[str], list[str]]:
added_nodes, waiting_nodes = [], []
for server in instances:
is_valid = await self.node_listener.check_instance_status(server.client)
@@ -332,7 +316,7 @@ class ProxyState:
waiting_nodes.append(node)
return added_nodes, waiting_nodes
def add_prefillers(self, instances: List[ServerState]) -> None:
def add_prefillers(self, instances: list[ServerState]) -> None:
num_prefillers = len(self.prefillers)
for idx, server in enumerate(instances):
if server not in self.prefillers:
@@ -341,7 +325,7 @@ class ProxyState:
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
self.print_status(f"Add prefiller instances: {instances}.")
def add_decoders(self, instances: List[ServerState]) -> None:
def add_decoders(self, instances: list[ServerState]) -> None:
num_decoders = len(self.decoders)
for idx, server in enumerate(instances):
if server not in self.decoders:
@@ -350,7 +334,7 @@ class ProxyState:
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
self.print_status(f"Add decoder instances: {instances}.")
def remove_prefillers(self, instances: List[ServerState]) -> None:
def remove_prefillers(self, instances: list[ServerState]) -> None:
instances_to_remove = set(instances)
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
prefiller_heap_copy = self.prefiller_heap.copy()
@@ -367,7 +351,7 @@ class ProxyState:
heapq.heapify(self.prefiller_heap)
self.print_status(f"Remove prefiller instances: {instances}.")
def remove_decoders(self, instances: List[ServerState]) -> None:
def remove_decoders(self, instances: list[ServerState]) -> None:
instances_to_remove = set(instances)
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
decoder_heap_copy = self.decoder_heap.copy()
@@ -387,7 +371,7 @@ class ProxyState:
def print_status(self, msg: str) -> None:
status = {
"prefill_instances": [str(server) for server in self.prefillers],
"decode_instances": [str(server) for server in self.decoders]
"decode_instances": [str(server) for server in self.decoders],
}
print(f"{msg} Status: {status}")
@@ -398,7 +382,7 @@ proxy_state = None
class NodeListener:
def __init__(self, proxy):
self.proxy_state = proxy
self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
self.waiting_nodes: dict[str, tuple[str, Any, int]] = {}
self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
self.listening_thread.start()
@@ -424,9 +408,7 @@ class NodeListener:
@staticmethod
async def check_instance_status(client: httpx.AsyncClient) -> bool:
endpoint = "/models"
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
try:
response = await client.get(endpoint, headers=headers)
response.raise_for_status()
@@ -439,46 +421,29 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--decoder-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
parser.add_argument("--max-waiting-retries",
type=int,
default=3,
help="Maximum number of retries for waiting nodes to be started")
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
)
parser.add_argument(
"--max-waiting-retries", type=int, default=3, help="Maximum number of retries for waiting nodes to be started"
)
parser.add_argument(
"--waiting-retry-interval",
type=float,
default=10,
help="Check interval (seconds) for waiting nodes to be started")
help="Check interval (seconds) for waiting nodes to be started",
)
args = parser.parse_args()
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
raise ValueError("Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
raise ValueError("Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
@@ -486,11 +451,8 @@ def parse_args():
@asynccontextmanager
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.prefiller_instances,
global_args.decoder_instances)
print(
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
)
proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
yield
for p in proxy_state.prefillers:
await p.client.aclose()
@@ -507,14 +469,12 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
@@ -527,17 +487,18 @@ def with_cancellation(handler_func):
app = FastAPI(lifespan=lifespan)
async def send_request_to_service(client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id)
async def send_request_to_service(
client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
req_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
@@ -553,46 +514,35 @@ async def send_request_to_service(client: httpx.AsyncClient,
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
last_exc = None
for attempt in range(1, max_retries + 1):
try:
response = await client.post(endpoint,
json=req_data,
headers=headers)
response = await client.post(endpoint, json=req_data, headers=headers)
response.raise_for_status()
return response
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
last_exc = e
if attempt < max_retries:
await asyncio.sleep(base_delay * (2**(attempt - 1)))
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for {endpoint}.")
logger.error(f"All {max_retries} attempts failed for {endpoint}.")
raise last_exc
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async def stream_service_response_with_retry(
client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
@@ -601,41 +551,28 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
if "first_chunk_sent" in locals() and first_chunk_sent:
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
async def _handle_select_instance(api: str, req_data: Any,
request_length: int):
async def _handle_select_instance(api: str, req_data: Any, request_length: int):
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
@@ -648,10 +585,11 @@ async def _handle_select_instance(api: str, req_data: Any,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
base_delay=global_args.retry_delay,
)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
kv_transfer_params = response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Select decoder
@@ -661,13 +599,15 @@ async def _handle_select_instance(api: str, req_data: Any,
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
return InstanceInfo(request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score)
return InstanceInfo(
request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score,
)
@dataclass
@@ -686,8 +626,7 @@ async def _handle_completions(api: str, request: Request):
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
instance_info = await _handle_select_instance(api, req_data,
request_length)
instance_info = await _handle_select_instance(api, req_data, request_length)
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data
@@ -713,34 +652,31 @@ async def _handle_completions(api: str, request: Request):
while retry:
retry = False
async for chunk in stream_service_response_with_retry(
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
instance_info.prefiller_idx,
instance_info.prefiller_score)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
released_kv = True
try:
chunk_str = chunk.decode("utf-8").strip()
except UnicodeDecodeError:
logger.debug(
f"Skipping chunk: {chunk}")
logger.debug(f"Skipping chunk: {chunk}")
yield chunk
continue
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: "):]
chunk_str = chunk_str[len("data: ") :]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.debug(
f"Skipping chunk: {chunk_str}")
logger.debug(f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
@@ -751,63 +687,52 @@ async def _handle_completions(api: str, request: Request):
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = (
delta.get("content")
or message.get("content")
or choice.get("text")
or ""
)
content = delta.get("content") or message.get("content") or choice.get("text") or ""
generated_token += content
stop_reason = choice.get(
"stop_reason")
stop_reason = choice.get("stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (completion_tokens + 1) if stream_flag else \
(completion_tokens + usage.get("completion_tokens"))
completion_tokens = (
(completion_tokens + 1)
if stream_flag
else (completion_tokens + usage.get("completion_tokens"))
)
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0][
"content"] = origin_prompt + generated_token
messages[0]["content"] = origin_prompt + generated_token
else:
req_data[
"prompt"] = origin_prompt + generated_token
req_data[
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(
json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(
api, req_data, tmp_request_length)
req_data["prompt"] = origin_prompt + generated_token
req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(api, req_data, tmp_request_length)
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choice["message"][
"content"] = generated_token
choice["message"]["content"] = generated_token
else:
choice["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} "
f"the aborted request {instance_info.request_id} will be routing to the target "
"prefiller when new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(
instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
instance_info.prefiller_score)
proxy_state.abort_prefiller_request(instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
# After streaming done, release tokens
proxy_state.release_decoder(instance_info.decoder_idx,
instance_info.decoder_score)
proxy_state.release_decoder(instance_info.decoder_idx, instance_info.decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -821,20 +746,21 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
if isinstance(instances, str):
instances = [instances]
instances = trans_instances(instances)
all_msg = f"{adjust_mode} {instance_type} instances: " \
f"{[str(server) for server in instances]}."
all_msg = f"{adjust_mode} {instance_type} instances: {[str(server) for server in instances]}."
if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
return {"error": f"Instance type {instance_type} is not supported. "
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
return {
"error": f"Instance type {instance_type} is not supported. "
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."
}
if adjust_mode == "add":
added_nodes, waiting_nodes = await proxy_state.add_instances(
instance_type, instances
)
added_nodes, waiting_nodes = await proxy_state.add_instances(instance_type, instances)
if waiting_nodes:
all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
f"Instances {waiting_nodes} are waiting to be added."
all_msg = (
f"{adjust_mode} {instance_type} instances: {added_nodes}. "
f"Instances {waiting_nodes} are waiting to be added."
)
elif adjust_mode == "remove":
if instance_type == InstanceType.PREFILL:
proxy_state.remove_prefillers(instances)
@@ -843,14 +769,14 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
return {
"message": all_msg,
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders],
}
except Exception as e:
logger.error(f"Failed to {adjust_mode} instances: {e}")
raise e
def trans_instances(instances: List[str]) -> List[ServerState]:
def trans_instances(instances: list[str]) -> list[ServerState]:
server_list = []
for instance in instances:
h, p = instance.split(":")
@@ -875,7 +801,7 @@ async def healthcheck():
return {
"status": "ok",
"prefill_instances": len(proxy_state.prefillers),
"decode_instances": len(proxy_state.decoders)
"decode_instances": len(proxy_state.decoders),
}
@@ -889,7 +815,7 @@ async def handle_remove_instances(request: Request):
return await _handle_adjust_instances("remove", request)
if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn