1. Fix format check error to make format.sh work 2. Add codespell check CI 3. Add the missing required package for vllm-ascend. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
464 lines
18 KiB
Python
464 lines
18 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import asyncio
|
|
import copy
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
import uuid
|
|
|
|
import aiohttp
|
|
import msgpack # type: ignore
|
|
import zmq
|
|
from quart import Quart, make_response, request
|
|
|
|
# Ports bound by this proxy: the Quart HTTP server, the ZMQ ROUTER socket on
# which D nodes register, and the ZMQ PULL socket on which they notify steps.
DP_PROXY_HTTP_PORT = 10004
DP_PROXY_ZMQ_REG_PORT = 30006
DP_PROXY_ZMQ_NOTIFY_PORT = 30005

# ZMQ endpoint (host:port) of the PD proxy this process registers with.
PD_PROXY_ADDRESS = "127.0.0.1:30002"

# Addresses advertised to the PD proxy in the registration payload.
MY_HTTP_ADDRESS = f"127.0.0.1:{DP_PROXY_HTTP_PORT}"
MY_ZMQ_ADDRESS_PLACEHOLDER = f"127.0.0.1:{DP_PROXY_ZMQ_REG_PORT}"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Threshold, compared against a time.time() difference (i.e. seconds), after
# which D nodes that have not yet notified the current step are sent an idle
# request.
TIME_INTERVAL_FOR_IDLE_RUN = 5e-4
# Number of registered D nodes required before the registration listener
# wakes the metadata collector thread.
DP_SIZE = 2

# Registered D nodes keyed by their "host:port" HTTP address; the listener
# stores True for every registered node.
dp_instances: dict[str, bool] = {}
# Condition used both as the lock around dp_instances and to signal that
# DP_SIZE nodes have registered.
dp_cv = threading.Condition()
# Monotonically increasing counter used for round-robin target selection.
round_robin_index = 0
# asyncio event loop used for idle requests; assigned in the __main__ block.
_idle_send_loop = None
|
|
|
|
|
|
def make_idle_request():
    """Build the completion payload used for idle requests.

    Returns:
        A new dict with the prompt ``"hi"``, ``max_tokens`` of 1 and
        ``temperature`` of 0.
    """
    return dict(prompt="hi", max_tokens=1, temperature=0)
|
|
|
|
|
|
def random_uuid() -> str:
    """Return a newly generated version-4 UUID as a 32-character hex string."""
    generated = uuid.uuid4()
    # ``UUID.hex`` is already a ``str``; no extra conversion is required.
    return generated.hex
|
|
|
|
|
|
async def send_idle_token_to_client(schedule_dict):
    """Send an idle completion request to every node not yet marked started.

    Args:
        schedule_dict: Mapping of D-node HTTP address ("host:port") to a flag;
            entries whose flag is truthy are skipped. The mapping is read
            lazily while iterating, so flags set by another thread while this
            coroutine is awaiting are honoured for entries not yet visited.

    Each request is forwarded through ``forward_request_internal`` and its
    response is drained chunk by chunk; failures are logged, never raised.
    """
    for node_addr, has_started in schedule_dict.items():
        if not has_started:
            idle_id = random_uuid()
            payload = make_idle_request()
            node_url = f'http://{node_addr}/v1/completions'
            forward_id = f"dp_idle_{node_addr}_{idle_id}"
            logger.debug(
                f"DP Decode Proxy: Sending idle token to D node {node_addr} at {node_url}"
            )
            chunks = forward_request_internal(node_url, payload, forward_id)
            try:
                # The response body is irrelevant; it is consumed only so the
                # request runs to completion.
                async for chunk in chunks:
                    logger.debug(
                        f"DP Decode Proxy: Idle Request {idle_id}: response from {node_addr}, got response: {chunk}"
                    )
            except Exception as err:
                logger.warning(
                    f"DP Decode Proxy: Error sending idle token to {node_addr}: {err}")
|
|
|
|
|
|
def metadata_collect_trigger(poller, router_socket):
    """Track per-step notifications from D nodes and nudge the idle ones.

    Blocks until ``DP_SIZE`` D nodes have registered, then loops forever.
    Each round starts with every registered node marked "not started"; a
    ``notify_step`` message flips the sender's flag. Once the first
    notification of a round has been seen, nodes that are still unflagged
    after ``TIME_INTERVAL_FOR_IDLE_RUN`` seconds are sent an idle request via
    ``send_idle_token_to_client`` scheduled on ``_idle_send_loop``. A round
    ends when every node has notified.

    Args:
        poller: ZMQ poller that has ``router_socket`` registered for input.
        router_socket: ZMQ socket delivering msgpack-encoded dicts shaped
            like ``{"info": "notify_step", "http_address": "<host:port>"}``.

    Returns:
        None. The loop is only left when ZMQ reports ``ETERM``.
    """
    with dp_cv:
        # A bare wait() would block forever if the listener's notify_all()
        # ran before this thread got here; checking the condition itself makes
        # start-up independent of thread scheduling.
        dp_cv.wait_for(lambda: len(dp_instances) >= DP_SIZE)
    while True:
        try:
            # New round: no registered node has reported this step yet.
            schedule_dict = dict.fromkeys(dp_instances, False)
            first_start = False
            start_time = None
            while not all(schedule_dict.values()):
                if start_time is not None:
                    time_interval = time.time() - start_time
                    logger.debug("check time interval: %s", time_interval)
                    if time_interval > TIME_INTERVAL_FOR_IDLE_RUN:
                        logger.debug(
                            "exceeds max time interval send idle token to client"
                        )
                        # Send an idle token so that a DP rank running alone
                        # does not block on the collective-communication part.
                        asyncio.run_coroutine_threadsafe(
                            send_idle_token_to_client(schedule_dict),
                            _idle_send_loop)  # type: ignore
                        # Only the timer is reset here (to avoid sending idle
                        # tokens back-to-back). The flags stay untouched
                        # because an idle token may get lost; the round ends
                        # only once every node has really notified.
                        start_time = time.time()
                socks = dict(poller.poll(timeout=500))  # timeout in 500ms
                if socks:
                    logger.debug("receive socks from monitor threads: %s",
                                 socks)
                if router_socket in socks:
                    messages = router_socket.recv_multipart()
                    # Bound up front so the error handler below can always
                    # format them, even when the very first unpack fails.
                    data = None
                    http_addr = None
                    try:
                        # {"info": "notify_step", "http_address": ""}
                        for message in messages:
                            # Reset per frame so a failure never reports the
                            # previous frame's values.
                            data = None
                            http_addr = None
                            data = msgpack.loads(message)
                            logger.debug(f"receive message {data}")
                            if data.get("info") == "notify_step":
                                http_addr = data.get("http_address")
                                if http_addr in schedule_dict:
                                    schedule_dict[http_addr] = True
                                    logger.debug("set first time")
                                    if not first_start:
                                        # The idle timer starts with the first
                                        # notification of the round.
                                        logger.debug("record start time")
                                        first_start = True
                                        start_time = time.time()
                                else:
                                    logger.warning("Unrecognize http address")
                            else:
                                logger.warning(
                                    "Got unrecognize info type! We only accept notify step info yet"
                                )
                    except (msgpack.UnpackException, TypeError, KeyError) as e:
                        logger.error(
                            f"Error processing message from {http_addr}: {e}. Message: {data}"
                        )
        except zmq.ZMQError as e:  # type: ignore
            logger.error(f"ZMQ Error in monitor thread: {e}")
            if e.errno == zmq.ETERM:  # type: ignore
                logger.error(
                    "Monitor thread terminating due to context termination.")
                break
            time.sleep(1)
        except Exception as e:
            logger.error(f"Unexpected error in monitor thread: {e}")
            import traceback
            traceback.print_exc()
            time.sleep(1)
|
|
|
|
|
|
def _listen_for_d_register(poller, router_socket):
    """Receive D-node registrations and record them in ``dp_instances``.

    Loops forever polling ``router_socket``. Each message is expected to be a
    two-frame multipart ``[identity, payload]`` where the payload is a
    msgpack-encoded dict with ``"type": "DP"`` plus ``"http_address"`` and
    ``"zmq_address"``. A previously unseen ``http_address`` is added to
    ``dp_instances``; once at least ``DP_SIZE`` nodes are registered, every
    waiter on ``dp_cv`` is notified.

    Args:
        poller: ZMQ poller that has ``router_socket`` registered for input.
        router_socket: ZMQ socket on which registrations arrive.

    Returns:
        None. The loop is only left when ZMQ reports ``ETERM``.
    """
    logger.info(
        f"DP Decode Proxy: D Node ZMQ Listener started on ROUTER port {DP_PROXY_ZMQ_REG_PORT}"
    )

    while True:
        try:
            socks = dict(poller.poll(timeout=1000))
            if router_socket not in socks:
                continue
            remote_id, message = router_socket.recv_multipart()
            # Peer identities are raw bytes and need not be valid UTF-8;
            # decode leniently so building a log line can never raise from
            # inside the warning/except paths below.
            sender = remote_id.decode(errors="replace")
            try:
                data = msgpack.loads(message)
                if data.get("type") != "DP":
                    logger.warning(
                        f"DP Decode Proxy: Received message with unexpected type from {sender}. Type: {data.get('type')}, Data: {data}"
                    )
                    continue

                http_addr = data.get("http_address")
                zmq_addr = data.get("zmq_address")
                if not http_addr:
                    logger.warning(
                        f"DP Decode Proxy: Received D Node registration from {sender} without http_address. Data: {data}"
                    )
                    continue

                with dp_cv:
                    # Repeated registrations of a known node are ignored.
                    if http_addr not in dp_instances:
                        logger.info(
                            f"DP Decode Proxy: Registering D Node instance: http={http_addr}, zmq={zmq_addr}"
                        )
                        dp_instances[http_addr] = True
                        if len(dp_instances) >= DP_SIZE:
                            logger.info(
                                f"DP Decode Proxy: Reached expected D Node count ({DP_SIZE}). Notifying metadata collector."
                            )
                            dp_cv.notify_all()

            except (msgpack.UnpackException, TypeError, KeyError) as e:
                logger.error(
                    f"DP Decode Proxy: Error processing D Node registration from {sender}: {e}. Message: {message}"
                )
            except Exception as e:
                logger.error(
                    f"DP Decode Proxy: Unexpected error processing D Node registration from {sender}: {e}"
                )

        except zmq.ZMQError as e:  # type: ignore
            logger.error(
                f"DP Decode Proxy: ZMQ Error in D Node listener thread: {e}")
            if e.errno == zmq.ETERM:  # type: ignore
                logger.info(
                    "DP Decode Proxy: D Node Listener thread terminating.")
                break
            time.sleep(1)
        except Exception as e:
            logger.error(
                f"DP Decode Proxy: Unexpected error in D Node listener thread: {e}"
            )
            import traceback
            traceback.print_exc()
            time.sleep(1)
|
|
|
|
|
|
def _register_to_pd_proxy(pd_proxy_zmq_addr, my_http_addr, my_zmq_addr):
    """Register this proxy with the PD proxy and keep re-sending the record.

    Runs forever. A DEALER socket is created and connected on demand; the
    registration record is then sent every 5 seconds. After any error the
    socket is closed and discarded, and the next attempt happens 10 seconds
    later with a fresh socket.

    Args:
        pd_proxy_zmq_addr: ``host:port`` of the PD proxy's ZMQ endpoint.
        my_http_addr: HTTP address advertised for this proxy; also used to
            build the socket identity.
        my_zmq_addr: ZMQ address advertised for this proxy.
    """
    # The record never changes between heartbeats, so build it once.
    registration = {
        "type": "D",
        "http_address": my_http_addr,
        "zmq_address": my_zmq_addr
    }
    context = None
    sock = None
    while True:
        try:
            if context is None:
                context = zmq.Context()  # type: ignore
            if sock is None:
                sock = context.socket(zmq.DEALER)  # type: ignore
                sock.setsockopt(  # type: ignore
                    zmq.IDENTITY, f"dp_proxy_{my_http_addr}".encode('utf-8'))
                sock.setsockopt(zmq.LINGER, 0)  # type: ignore
                logger.info(
                    f"DP Decode Proxy: Attempting to connect to PD Proxy at {pd_proxy_zmq_addr}..."
                )
                sock.connect(f"tcp://{pd_proxy_zmq_addr}")
                logger.info(
                    f"DP Decode Proxy: Connected to PD Proxy at {pd_proxy_zmq_addr}."
                )

            logger.debug(
                f"DP Decode Proxy: Sending registration/heartbeat to PD Proxy: {registration}"
            )
            sock.send(msgpack.dumps(registration))
            time.sleep(5)

        except zmq.ZMQError as zmq_err:  # type: ignore
            logger.error(
                f"DP Decode Proxy: ZMQ Error connecting/sending to PD Proxy ({pd_proxy_zmq_addr}): {zmq_err}"
            )
            # Drop the socket so the next iteration reconnects from scratch.
            if sock:
                sock.close()
                sock = None
            time.sleep(10)
        except Exception as err:
            logger.error(
                f"DP Decode Proxy: Unexpected error in PD Proxy registration thread: {err}"
            )
            import traceback
            traceback.print_exc()
            if sock:
                sock.close()
                sock = None
            time.sleep(10)
|
|
|
|
|
|
def start_zmq_thread(hostname, port, socket_type, target_func, thread_name):
    """Bind a ZMQ socket and run ``target_func`` on it in a daemon thread.

    Args:
        hostname: Interface to bind; a falsy value means ``"0.0.0.0"``.
        port: TCP port to bind.
        socket_type: ZMQ socket type constant (the callers in this file pass
            ROUTER or PULL).
        target_func: Callable invoked as ``target_func(poller, socket)``.
        thread_name: Name given to the created thread.

    Returns:
        A ``(thread, socket)`` tuple; the thread has already been started.

    Raises:
        zmq.ZMQError: If binding fails. The socket is closed before the
            error is re-raised.
    """
    bind_host = hostname or "0.0.0.0"
    endpoint = f"tcp://{bind_host}:{port}"

    zmq_socket = zmq.Context.instance().socket(socket_type)  # type: ignore
    zmq_socket.setsockopt(zmq.LINGER, 0)  # type: ignore
    try:
        zmq_socket.bind(endpoint)
    except zmq.ZMQError as bind_err:  # type: ignore
        logger.error(
            f"DP Decode Proxy: Error binding ZMQ {socket_type} socket to {endpoint}: {bind_err}"
        )
        zmq_socket.close()
        raise

    # The worker gets a poller that already watches the socket for input.
    input_poller = zmq.Poller()  # type: ignore
    input_poller.register(zmq_socket, zmq.POLLIN)  # type: ignore

    worker = threading.Thread(target=target_func,
                              args=(input_poller, zmq_socket),
                              daemon=True,
                              name=thread_name)
    worker.start()
    return worker, zmq_socket
|
|
|
|
|
|
def start_thread_with_event_loop():
    """Install ``_idle_send_loop`` as the calling thread's loop and run it.

    Intended as a thread target: the call blocks in ``run_forever()`` until
    the loop is stopped.
    """
    loop = _idle_send_loop
    asyncio.set_event_loop(loop)
    loop.run_forever()  # type: ignore
|
|
|
|
|
|
async def forward_request_internal(url, data, request_id):
    """POST ``data`` as JSON to ``url`` and yield the response body as bytes.

    Args:
        url: Target completions endpoint.
        data: JSON-serialisable request body.
        request_id: Value sent in the ``X-Request-Id`` header.

    Yields:
        On HTTP 200, the body in chunks of up to 1024 bytes. On any other
        status, the whole error body as a single chunk. If an
        ``aiohttp.ClientError`` occurs, one UTF-8 encoded message describing
        the failure.
    """
    api_key = os.environ.get('OPENAI_API_KEY', '')
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "X-Request-Id": request_id,
        "Content-Type": "application/json"
    }
    try:
        async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
            async with session.post(url=url, json=data,
                                    headers=request_headers) as response:
                if response.status != 200:
                    # Pass the upstream error body through to the caller.
                    error_content = await response.read()
                    logger.warning(
                        f"DP Decode Proxy: Error from D node {url} (status {response.status}): {error_content.decode(errors='ignore')}"
                    )
                    yield error_content
                else:
                    async for piece in response.content.iter_chunked(1024):
                        yield piece

    except aiohttp.ClientError as e:
        logger.warning(
            f"DP Decode Proxy: Error forwarding request {request_id} to D node {url}: {e}"
        )
        yield f"Failed to connect or communicate with D node at {url}: {e}".encode(
            'utf-8')
|
|
|
|
|
|
# Timeout applied to every forwarded request; total=6 * 60 * 60 is six hours,
# assuming aiohttp's ClientTimeout unit of seconds.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
# Quart application that exposes the proxy's HTTP endpoint.
app = Quart(__name__)
|
|
|
|
|
|
@app.route('/v1/completions', methods=['POST'])
async def handle_request():
    """Forward a completion request to one registered D node (round robin).

    The request id is taken from the ``X-Request-Id`` header, or generated
    when the header is missing. The JSON body is forwarded unchanged to the
    selected node's ``/v1/completions`` endpoint and the node's response is
    streamed back.

    Returns:
        The streamed upstream response; 400 when the body is empty or not
        JSON-decodable to a truthy value; 503 when no D node is registered;
        500 on any unexpected error.
    """
    global round_robin_index

    request_received_id = request.headers.get("X-Request-Id")
    if not request_received_id:
        fallback_id = f"dp_fallback_{random_uuid()}"
        logger.warning(
            f"DP Decode Proxy: Received request without X-Request-Id header. Using fallback ID: {fallback_id}"
        )
        request_received_id = fallback_id
    else:
        logger.info(
            f"DP Decode Proxy: Received request from PD Proxy, using propagated ID: {request_received_id}"
        )

    try:
        original_request_data = await request.get_json()
        if not original_request_data:
            return await make_response("Request body must be valid JSON.", 400)

        # Pick the target while holding the lock, but never await inside it:
        # dp_cv is a threading lock, and suspending a coroutine that owns it
        # would stall the registration listener thread.
        target_addr = None
        current_selection_index = None
        with dp_cv:
            dp_addresses = list(dp_instances)
            if dp_addresses:
                current_selection_index = round_robin_index % len(dp_addresses)
                target_addr = dp_addresses[current_selection_index]
                round_robin_index += 1

        if target_addr is None:
            logger.warning(
                f"DP Decode Proxy: Request {request_received_id}: No D Node instances available/registered."
            )
            return await make_response("No Decode instances available.",
                                       503)

        logger.info(
            f"DP Decode Proxy: Request {request_received_id}: Routing Decode to D Node {target_addr} (Index {current_selection_index})"
        )

        target_url = f'http://{target_addr}/v1/completions'

        generator = forward_request_internal(target_url, original_request_data,
                                             request_received_id)

        response = await make_response(generator)
        # Disable the response timeout so long generations are not cut off.
        response.timeout = None

        if original_request_data.get("stream", False):
            response.headers['Content-Type'] = 'text/event-stream'
            response.headers['Cache-Control'] = 'no-cache'
        else:
            response.headers['Content-Type'] = 'application/json'

        logger.debug(
            f"DP Decode Proxy: Request {request_received_id}: Streaming response from D node {target_addr}"
        )
        return response

    except Exception as e:
        logger.error(
            f"DP Decode Proxy: Error handling request {request_received_id}: {e}"
        )
        return await make_response("Internal Server Error", 500)
|
|
|
|
|
|
if __name__ == '__main__':
    # ROUTER socket + thread that accepts D-node registrations.
    d_listener_thread, d_reg_socket = start_zmq_thread(
        "0.0.0.0",
        DP_PROXY_ZMQ_REG_PORT,
        zmq.ROUTER,  # type: ignore
        _listen_for_d_register,  # type: ignore
        "DP_DNodeListenerThread")

    # PULL socket + thread that collects per-step notifications.
    metadata_thread, notify_socket = start_zmq_thread(
        "0.0.0.0",
        DP_PROXY_ZMQ_NOTIFY_PORT,
        zmq.PULL,  # type: ignore
        metadata_collect_trigger,
        "DP_MetadataMonitorThread")

    # Dedicated event loop, run on its own thread, onto which the metadata
    # thread schedules idle requests.
    _idle_send_loop = asyncio.new_event_loop()
    idle_loop_thread = threading.Thread(target=start_thread_with_event_loop,
                                        daemon=True,
                                        name="DP_IdleSendLoopThread")
    idle_loop_thread.start()

    # Background thread that keeps this proxy registered with the PD proxy.
    pd_register_thread = threading.Thread(target=_register_to_pd_proxy,
                                          args=(PD_PROXY_ADDRESS,
                                                MY_HTTP_ADDRESS,
                                                MY_ZMQ_ADDRESS_PLACEHOLDER),
                                          daemon=True,
                                          name="DP_PDRegisterThread")
    pd_register_thread.start()

    logger.info(
        f"DP Decode Proxy: Starting Quart web server on http://0.0.0.0:{DP_PROXY_HTTP_PORT}"
    )
    # Same shared context instance that start_zmq_thread created its sockets
    # from; kept so it can be terminated on shutdown.
    zmq_context = zmq.Context.instance()  # type: ignore
    try:
        # Blocks until the server stops.
        app.run(host='0.0.0.0', port=DP_PROXY_HTTP_PORT)
    except KeyboardInterrupt:
        logger.info("DP Decode Proxy: KeyboardInterrupt received, stopping...")
    except Exception as e:
        logger.error(f"DP Decode Proxy: Failed to run Quart server: {e}")
    finally:
        logger.info("DP Decode Proxy: Shutting down...")
        if _idle_send_loop and _idle_send_loop.is_running():
            logger.info("DP Decode Proxy: Stopping idle send loop...")
            # The loop runs on another thread, so stop() has to be scheduled
            # through call_soon_threadsafe.
            _idle_send_loop.call_soon_threadsafe(_idle_send_loop.stop)

        # Close the bound sockets before terminating the shared context.
        if d_reg_socket:
            d_reg_socket.close()
        if notify_socket:
            notify_socket.close()
        if zmq_context:
            zmq_context.term()

        logger.info("DP Decode Proxy: Shutdown complete.")
|