[Lint]Style: Convert example to ruff format (#5863)
### What this PR does / why we need it?
This PR fixes linting issues in the `example/` directory to align with
the project's Ruff configuration.
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
Signed-off-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Co-authored-by: root <root@LAPTOP-VQKDDVMG.localdomain>
This commit is contained in:
@@ -125,7 +125,7 @@ import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Tuple, Dict
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
@@ -150,22 +150,21 @@ class InstanceType:
|
||||
|
||||
|
||||
class ServerState:
|
||||
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.url = f'http://{host}:{port}/v1'
|
||||
self.url = f"http://{host}:{port}/v1"
|
||||
try:
|
||||
ip = ipaddress.ip_address(self.host)
|
||||
if isinstance(ip, ipaddress.IPv6Address):
|
||||
self.url = f'http://[{host}]:{port}/v1'
|
||||
self.url = f"http://[{host}]:{port}/v1"
|
||||
except Exception:
|
||||
pass
|
||||
self.client = httpx.AsyncClient(timeout=None,
|
||||
base_url=self.url,
|
||||
limits=httpx.Limits(
|
||||
max_connections=100000,
|
||||
max_keepalive_connections=100000))
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=None,
|
||||
base_url=self.url,
|
||||
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
|
||||
)
|
||||
self.active_tokens = 0
|
||||
self.active_kv_cache = 0 # Only for prefiller
|
||||
self.active_requests = 0 # Number of active requests
|
||||
@@ -186,16 +185,11 @@ class ServerState:
|
||||
|
||||
|
||||
class ProxyState:
|
||||
|
||||
def __init__(self, prefiller_instances, decoder_instances):
|
||||
self.node_listener = NodeListener(self)
|
||||
|
||||
self.prefillers: List[ServerState] = [
|
||||
ServerState(h, p) for h, p in prefiller_instances
|
||||
]
|
||||
self.decoders: List[ServerState] = [
|
||||
ServerState(h, p) for h, p in decoder_instances
|
||||
]
|
||||
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
|
||||
self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
|
||||
self.req_to_prefiller = {}
|
||||
self.req_id_lock = asyncio.Lock()
|
||||
# Removed selection locks - no longer needed for synchronous methods
|
||||
@@ -203,10 +197,8 @@ class ProxyState:
|
||||
# Initialize priority queues for efficient server selection
|
||||
# Each entry is (priority_score, server_index, server_reference)
|
||||
# Lower priority score = higher priority (less loaded)
|
||||
self.prefiller_heap = [(0, i, server)
|
||||
for i, server in enumerate(self.prefillers)]
|
||||
self.decoder_heap = [(0, i, server)
|
||||
for i, server in enumerate(self.decoders)]
|
||||
self.prefiller_heap = [(0, i, server) for i, server in enumerate(self.prefillers)]
|
||||
self.decoder_heap = [(0, i, server) for i, server in enumerate(self.decoders)]
|
||||
heapq.heapify(self.prefiller_heap)
|
||||
heapq.heapify(self.decoder_heap)
|
||||
|
||||
@@ -216,23 +208,18 @@ class ProxyState:
|
||||
# Priority based on active_tokens and active_kv_cache
|
||||
priority = server.active_tokens + server.active_kv_cache * 0.3
|
||||
# Remove old entry and add new one
|
||||
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
|
||||
if i != server_idx]
|
||||
heapq.heappush(self.prefiller_heap,
|
||||
(priority, server_idx, server)) # type: ignore
|
||||
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
|
||||
heapq.heappush(self.prefiller_heap, (priority, server_idx, server)) # type: ignore
|
||||
|
||||
def _update_decoder_priority(self, server_idx: int):
|
||||
"""Update the priority of a decoder server in the heap."""
|
||||
server = self.decoders[server_idx]
|
||||
priority = server.active_tokens
|
||||
# Remove old entry and add new one
|
||||
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
|
||||
if i != server_idx]
|
||||
heapq.heappush(self.decoder_heap,
|
||||
(priority, server_idx, server)) # type: ignore
|
||||
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
|
||||
heapq.heappush(self.decoder_heap, (priority, server_idx, server)) # type: ignore
|
||||
|
||||
def abort_prefiller_request(self, server_idx: int,
|
||||
request_id): # Changed to synchronous
|
||||
def abort_prefiller_request(self, server_idx: int, request_id): # Changed to synchronous
|
||||
"""
|
||||
Mark a request as aborted. This will helps to release kv cache in
|
||||
prefiller node.
|
||||
@@ -240,8 +227,7 @@ class ProxyState:
|
||||
# No lock needed - atomic operation
|
||||
self.prefillers[server_idx].aborted_requests.add(request_id)
|
||||
|
||||
def aquire_aborted_prefiller_requests(
|
||||
self, server_idx: int): # Changed to synchronous
|
||||
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
||||
"""
|
||||
Get the set of aborted requests and clear it.
|
||||
This is used to release kv cache in prefiller node.
|
||||
@@ -314,9 +300,7 @@ class ProxyState:
|
||||
def calculate_decode_scores(self, request_length: int) -> float:
|
||||
return request_length
|
||||
|
||||
async def add_instances(
|
||||
self, instance_type: str, instances: List[ServerState]
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
async def add_instances(self, instance_type: str, instances: list[ServerState]) -> tuple[list[str], list[str]]:
|
||||
added_nodes, waiting_nodes = [], []
|
||||
for server in instances:
|
||||
is_valid = await self.node_listener.check_instance_status(server.client)
|
||||
@@ -332,7 +316,7 @@ class ProxyState:
|
||||
waiting_nodes.append(node)
|
||||
return added_nodes, waiting_nodes
|
||||
|
||||
def add_prefillers(self, instances: List[ServerState]) -> None:
|
||||
def add_prefillers(self, instances: list[ServerState]) -> None:
|
||||
num_prefillers = len(self.prefillers)
|
||||
for idx, server in enumerate(instances):
|
||||
if server not in self.prefillers:
|
||||
@@ -341,7 +325,7 @@ class ProxyState:
|
||||
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
|
||||
self.print_status(f"Add prefiller instances: {instances}.")
|
||||
|
||||
def add_decoders(self, instances: List[ServerState]) -> None:
|
||||
def add_decoders(self, instances: list[ServerState]) -> None:
|
||||
num_decoders = len(self.decoders)
|
||||
for idx, server in enumerate(instances):
|
||||
if server not in self.decoders:
|
||||
@@ -350,7 +334,7 @@ class ProxyState:
|
||||
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
|
||||
self.print_status(f"Add decoder instances: {instances}.")
|
||||
|
||||
def remove_prefillers(self, instances: List[ServerState]) -> None:
|
||||
def remove_prefillers(self, instances: list[ServerState]) -> None:
|
||||
instances_to_remove = set(instances)
|
||||
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
|
||||
prefiller_heap_copy = self.prefiller_heap.copy()
|
||||
@@ -367,7 +351,7 @@ class ProxyState:
|
||||
heapq.heapify(self.prefiller_heap)
|
||||
self.print_status(f"Remove prefiller instances: {instances}.")
|
||||
|
||||
def remove_decoders(self, instances: List[ServerState]) -> None:
|
||||
def remove_decoders(self, instances: list[ServerState]) -> None:
|
||||
instances_to_remove = set(instances)
|
||||
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
|
||||
decoder_heap_copy = self.decoder_heap.copy()
|
||||
@@ -387,7 +371,7 @@ class ProxyState:
|
||||
def print_status(self, msg: str) -> None:
|
||||
status = {
|
||||
"prefill_instances": [str(server) for server in self.prefillers],
|
||||
"decode_instances": [str(server) for server in self.decoders]
|
||||
"decode_instances": [str(server) for server in self.decoders],
|
||||
}
|
||||
print(f"{msg} Status: {status}")
|
||||
|
||||
@@ -398,7 +382,7 @@ proxy_state = None
|
||||
class NodeListener:
|
||||
def __init__(self, proxy):
|
||||
self.proxy_state = proxy
|
||||
self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
|
||||
self.waiting_nodes: dict[str, tuple[str, Any, int]] = {}
|
||||
self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
|
||||
self.listening_thread.start()
|
||||
|
||||
@@ -424,9 +408,7 @@ class NodeListener:
|
||||
@staticmethod
|
||||
async def check_instance_status(client: httpx.AsyncClient) -> bool:
|
||||
endpoint = "/models"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
try:
|
||||
response = await client.get(endpoint, headers=headers)
|
||||
response.raise_for_status()
|
||||
@@ -439,46 +421,29 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--prefiller-hosts",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["localhost"])
|
||||
parser.add_argument("--prefiller-ports",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[8001])
|
||||
parser.add_argument("--decoder-hosts",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["localhost"])
|
||||
parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
|
||||
parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
|
||||
parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
|
||||
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
|
||||
parser.add_argument("--max-retries",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of retries for HTTP requests")
|
||||
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
|
||||
parser.add_argument(
|
||||
"--retry-delay",
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="Base delay (seconds) for exponential backoff retries")
|
||||
parser.add_argument("--max-waiting-retries",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of retries for waiting nodes to be started")
|
||||
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-waiting-retries", type=int, default=3, help="Maximum number of retries for waiting nodes to be started"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--waiting-retry-interval",
|
||||
type=float,
|
||||
default=10,
|
||||
help="Check interval (seconds) for waiting nodes to be started")
|
||||
help="Check interval (seconds) for waiting nodes to be started",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
||||
raise ValueError(
|
||||
"Number of prefiller hosts must match number of prefiller ports")
|
||||
raise ValueError("Number of prefiller hosts must match number of prefiller ports")
|
||||
if len(args.decoder_hosts) != len(args.decoder_ports):
|
||||
raise ValueError(
|
||||
"Number of decoder hosts must match number of decoder ports")
|
||||
args.prefiller_instances = list(
|
||||
zip(args.prefiller_hosts, args.prefiller_ports))
|
||||
raise ValueError("Number of decoder hosts must match number of decoder ports")
|
||||
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
|
||||
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
|
||||
return args
|
||||
|
||||
@@ -486,11 +451,8 @@ def parse_args():
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global proxy_state
|
||||
proxy_state = ProxyState(global_args.prefiller_instances,
|
||||
global_args.decoder_instances)
|
||||
print(
|
||||
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
|
||||
)
|
||||
proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
|
||||
print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
|
||||
yield
|
||||
for p in proxy_state.prefillers:
|
||||
await p.client.aclose()
|
||||
@@ -507,14 +469,12 @@ async def listen_for_disconnect(request: Request) -> None:
|
||||
|
||||
|
||||
def with_cancellation(handler_func):
|
||||
|
||||
@functools.wraps(handler_func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
request = kwargs["request"]
|
||||
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
||||
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
||||
done, pending = await asyncio.wait([handler_task, cancellation_task],
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if handler_task in done:
|
||||
@@ -527,17 +487,18 @@ def with_cancellation(handler_func):
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
async def send_request_to_service(client: httpx.AsyncClient,
|
||||
prefiller_id: int,
|
||||
endpoint: str,
|
||||
req_data: dict,
|
||||
request_id: str,
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 0.2):
|
||||
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
|
||||
prefiller_id)
|
||||
async def send_request_to_service(
|
||||
client: httpx.AsyncClient,
|
||||
prefiller_id: int,
|
||||
endpoint: str,
|
||||
req_data: dict,
|
||||
request_id: str,
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 0.2,
|
||||
):
|
||||
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
|
||||
req_data = req_data.copy()
|
||||
req_data['kv_transfer_params'] = {
|
||||
req_data["kv_transfer_params"] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"remote_engine_id": None,
|
||||
@@ -553,46 +514,35 @@ async def send_request_to_service(client: httpx.AsyncClient,
|
||||
req_data["max_completion_tokens"] = 1
|
||||
if "stream_options" in req_data:
|
||||
del req_data["stream_options"]
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
|
||||
last_exc = None
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
response = await client.post(endpoint,
|
||||
json=req_data,
|
||||
headers=headers)
|
||||
response = await client.post(endpoint, json=req_data, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||
logger.warning(
|
||||
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
|
||||
logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
|
||||
last_exc = e
|
||||
if attempt < max_retries:
|
||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
||||
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||
else:
|
||||
logger.error(
|
||||
f"All {max_retries} attempts failed for {endpoint}.")
|
||||
logger.error(f"All {max_retries} attempts failed for {endpoint}.")
|
||||
raise last_exc
|
||||
|
||||
|
||||
async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
||||
endpoint: str,
|
||||
req_data: dict,
|
||||
request_id: str,
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 0.2):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id
|
||||
}
|
||||
async def stream_service_response_with_retry(
|
||||
client: httpx.AsyncClient,
|
||||
endpoint: str,
|
||||
req_data: dict,
|
||||
request_id: str,
|
||||
max_retries: int = 3,
|
||||
base_delay: float = 0.2,
|
||||
):
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
async with client.stream("POST",
|
||||
endpoint,
|
||||
json=req_data,
|
||||
headers=headers) as response:
|
||||
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
first_chunk_sent = False
|
||||
async for chunk in response.aiter_bytes():
|
||||
@@ -601,41 +551,28 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
||||
return # Success, exit after streaming
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
||||
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
|
||||
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||
else:
|
||||
logger.error(
|
||||
f"All {max_retries} attempts failed for streaming {endpoint}."
|
||||
)
|
||||
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
|
||||
raise e
|
||||
except Exception as e:
|
||||
# If any chunk has been sent, do not retry, just log and drop
|
||||
if 'first_chunk_sent' in locals() and first_chunk_sent:
|
||||
logger.error(
|
||||
f"Streaming to client interrupted after response started: {str(e)}"
|
||||
)
|
||||
if "first_chunk_sent" in locals() and first_chunk_sent:
|
||||
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
|
||||
return
|
||||
else:
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
||||
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
|
||||
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
|
||||
else:
|
||||
logger.error(
|
||||
f"All {max_retries} attempts failed for streaming {endpoint}."
|
||||
)
|
||||
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
|
||||
raise e
|
||||
|
||||
|
||||
async def _handle_select_instance(api: str, req_data: Any,
|
||||
request_length: int):
|
||||
async def _handle_select_instance(api: str, req_data: Any, request_length: int):
|
||||
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
|
||||
logger.debug(
|
||||
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
|
||||
)
|
||||
logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
|
||||
request_id = await proxy_state.next_req_id()
|
||||
# Select prefiller
|
||||
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
|
||||
@@ -648,10 +585,11 @@ async def _handle_select_instance(api: str, req_data: Any,
|
||||
req_data,
|
||||
request_id,
|
||||
max_retries=global_args.max_retries,
|
||||
base_delay=global_args.retry_delay)
|
||||
base_delay=global_args.retry_delay,
|
||||
)
|
||||
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
|
||||
response_json = response.json()
|
||||
kv_transfer_params = response_json.get('kv_transfer_params', {})
|
||||
kv_transfer_params = response_json.get("kv_transfer_params", {})
|
||||
if kv_transfer_params:
|
||||
req_data["kv_transfer_params"] = kv_transfer_params
|
||||
# Select decoder
|
||||
@@ -661,13 +599,15 @@ async def _handle_select_instance(api: str, req_data: Any,
|
||||
decoder_idx = proxy_state.select_decoder(decoder_score)
|
||||
decoder = proxy_state.decoders[decoder_idx]
|
||||
logger.debug("Using %s %s", prefiller.url, decoder.url)
|
||||
return InstanceInfo(request_id=request_id,
|
||||
prefiller_idx=prefiller_idx,
|
||||
prefiller_score=prefiller_score,
|
||||
prefiller=prefiller,
|
||||
decoder=decoder,
|
||||
decoder_idx=decoder_idx,
|
||||
decoder_score=decoder_score)
|
||||
return InstanceInfo(
|
||||
request_id=request_id,
|
||||
prefiller_idx=prefiller_idx,
|
||||
prefiller_score=prefiller_score,
|
||||
prefiller=prefiller,
|
||||
decoder=decoder,
|
||||
decoder_idx=decoder_idx,
|
||||
decoder_score=decoder_score,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -686,8 +626,7 @@ async def _handle_completions(api: str, request: Request):
|
||||
req_data = await request.json()
|
||||
req_body = await request.body()
|
||||
request_length = len(req_body)
|
||||
instance_info = await _handle_select_instance(api, req_data,
|
||||
request_length)
|
||||
instance_info = await _handle_select_instance(api, req_data, request_length)
|
||||
stream_flag = bool(req_data.get("stream", False))
|
||||
chat_flag = "messages" in req_data
|
||||
|
||||
@@ -713,34 +652,31 @@ async def _handle_completions(api: str, request: Request):
|
||||
while retry:
|
||||
retry = False
|
||||
async for chunk in stream_service_response_with_retry(
|
||||
instance_info.decoder.client,
|
||||
api,
|
||||
req_data,
|
||||
request_id=instance_info.request_id,
|
||||
max_retries=global_args.max_retries,
|
||||
base_delay=global_args.retry_delay):
|
||||
instance_info.decoder.client,
|
||||
api,
|
||||
req_data,
|
||||
request_id=instance_info.request_id,
|
||||
max_retries=global_args.max_retries,
|
||||
base_delay=global_args.retry_delay,
|
||||
):
|
||||
if not released_kv and chunk:
|
||||
proxy_state.release_prefiller_kv(
|
||||
instance_info.prefiller_idx,
|
||||
instance_info.prefiller_score)
|
||||
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
|
||||
released_kv = True
|
||||
try:
|
||||
chunk_str = chunk.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
logger.debug(
|
||||
f"Skipping chunk: {chunk}")
|
||||
logger.debug(f"Skipping chunk: {chunk}")
|
||||
yield chunk
|
||||
continue
|
||||
if not chunk_str:
|
||||
continue
|
||||
if chunk_str.startswith("data: "):
|
||||
chunk_str = chunk_str[len("data: "):]
|
||||
chunk_str = chunk_str[len("data: ") :]
|
||||
try:
|
||||
chunk_json = json.loads(chunk_str)
|
||||
except json.JSONDecodeError:
|
||||
# if chunk is [done], skip it.
|
||||
logger.debug(
|
||||
f"Skipping chunk: {chunk_str}")
|
||||
logger.debug(f"Skipping chunk: {chunk_str}")
|
||||
yield chunk
|
||||
continue
|
||||
choices = chunk_json.get("choices", [])
|
||||
@@ -751,63 +687,52 @@ async def _handle_completions(api: str, request: Request):
|
||||
choice = choices[0]
|
||||
delta = choice.get("delta") or {}
|
||||
message = choice.get("message") or {}
|
||||
content = (
|
||||
delta.get("content")
|
||||
or message.get("content")
|
||||
or choice.get("text")
|
||||
or ""
|
||||
)
|
||||
content = delta.get("content") or message.get("content") or choice.get("text") or ""
|
||||
generated_token += content
|
||||
|
||||
stop_reason = choice.get(
|
||||
"stop_reason")
|
||||
stop_reason = choice.get("stop_reason")
|
||||
usage = chunk_json.get("usage", {})
|
||||
completion_tokens = (completion_tokens + 1) if stream_flag else \
|
||||
(completion_tokens + usage.get("completion_tokens"))
|
||||
completion_tokens = (
|
||||
(completion_tokens + 1)
|
||||
if stream_flag
|
||||
else (completion_tokens + usage.get("completion_tokens"))
|
||||
)
|
||||
if stop_reason == "recomputed":
|
||||
retry = True
|
||||
retry_count += 1
|
||||
if chat_flag:
|
||||
messages[0][
|
||||
"content"] = origin_prompt + generated_token
|
||||
messages[0]["content"] = origin_prompt + generated_token
|
||||
else:
|
||||
req_data[
|
||||
"prompt"] = origin_prompt + generated_token
|
||||
req_data[
|
||||
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
|
||||
tmp_request_length = len(
|
||||
json.dumps(req_data).encode("utf-8"))
|
||||
instance_info = await _handle_select_instance(
|
||||
api, req_data, tmp_request_length)
|
||||
req_data["prompt"] = origin_prompt + generated_token
|
||||
req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
|
||||
tmp_request_length = len(json.dumps(req_data).encode("utf-8"))
|
||||
instance_info = await _handle_select_instance(api, req_data, tmp_request_length)
|
||||
break
|
||||
if retry_count > 0 and not stream_flag:
|
||||
if chat_flag:
|
||||
choice["message"][
|
||||
"content"] = generated_token
|
||||
choice["message"]["content"] = generated_token
|
||||
else:
|
||||
choice["text"] = generated_token
|
||||
chunk = json.dumps(chunk_json).encode("utf-8")
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
|
||||
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} "
|
||||
f"the aborted request {instance_info.request_id} will be routing to the target "
|
||||
"prefiller when new request is ready to dispatch to it"
|
||||
)
|
||||
proxy_state.abort_prefiller_request(
|
||||
instance_info.prefiller_idx, instance_info.request_id)
|
||||
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
|
||||
instance_info.prefiller_score)
|
||||
proxy_state.abort_prefiller_request(instance_info.prefiller_idx, instance_info.request_id)
|
||||
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
|
||||
|
||||
# After streaming done, release tokens
|
||||
proxy_state.release_decoder(instance_info.decoder_idx,
|
||||
instance_info.decoder_score)
|
||||
proxy_state.release_decoder(instance_info.decoder_idx, instance_info.decoder_score)
|
||||
|
||||
return StreamingResponse(generate_stream(),
|
||||
media_type="application/json")
|
||||
return StreamingResponse(generate_stream(), media_type="application/json")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
exc_info = sys.exc_info()
|
||||
print("Error occurred in disagg prefill proxy server"
|
||||
f" - {api} endpoint")
|
||||
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
|
||||
print(e)
|
||||
print("".join(traceback.format_exception(*exc_info)))
|
||||
raise
|
||||
@@ -821,20 +746,21 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
||||
if isinstance(instances, str):
|
||||
instances = [instances]
|
||||
instances = trans_instances(instances)
|
||||
all_msg = f"{adjust_mode} {instance_type} instances: " \
|
||||
f"{[str(server) for server in instances]}."
|
||||
all_msg = f"{adjust_mode} {instance_type} instances: {[str(server) for server in instances]}."
|
||||
|
||||
if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
|
||||
return {"error": f"Instance type {instance_type} is not supported. "
|
||||
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
|
||||
return {
|
||||
"error": f"Instance type {instance_type} is not supported. "
|
||||
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."
|
||||
}
|
||||
|
||||
if adjust_mode == "add":
|
||||
added_nodes, waiting_nodes = await proxy_state.add_instances(
|
||||
instance_type, instances
|
||||
)
|
||||
added_nodes, waiting_nodes = await proxy_state.add_instances(instance_type, instances)
|
||||
if waiting_nodes:
|
||||
all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
|
||||
f"Instances {waiting_nodes} are waiting to be added."
|
||||
all_msg = (
|
||||
f"{adjust_mode} {instance_type} instances: {added_nodes}. "
|
||||
f"Instances {waiting_nodes} are waiting to be added."
|
||||
)
|
||||
elif adjust_mode == "remove":
|
||||
if instance_type == InstanceType.PREFILL:
|
||||
proxy_state.remove_prefillers(instances)
|
||||
@@ -843,14 +769,14 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
||||
return {
|
||||
"message": all_msg,
|
||||
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
|
||||
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
|
||||
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders],
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to {adjust_mode} instances: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def trans_instances(instances: List[str]) -> List[ServerState]:
|
||||
def trans_instances(instances: list[str]) -> list[ServerState]:
|
||||
server_list = []
|
||||
for instance in instances:
|
||||
h, p = instance.split(":")
|
||||
@@ -875,7 +801,7 @@ async def healthcheck():
|
||||
return {
|
||||
"status": "ok",
|
||||
"prefill_instances": len(proxy_state.prefillers),
|
||||
"decode_instances": len(proxy_state.decoders)
|
||||
"decode_instances": len(proxy_state.decoders),
|
||||
}
|
||||
|
||||
|
||||
@@ -889,7 +815,7 @@ async def handle_remove_instances(request: Request):
|
||||
return await _handle_adjust_instances("remove", request)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
global global_args
|
||||
global_args = parse_args()
|
||||
import uvicorn
|
||||
|
||||
Reference in New Issue
Block a user