[Misc]Remove PD v0 code (#2047)
Cleanup V0 disaggregated prefill code for V0 Engine.
part of https://github.com/vllm-project/vllm-ascend/issues/1620
TODO: enable v1 e2e test.
- vLLM version: v0.10.0
- vLLM main:
2cc571199b
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -1,141 +0,0 @@
|
||||
"""
|
||||
This file demonstrates the example usage of disaggregated prefilling
|
||||
We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode),
|
||||
and then transfer the KV cache between them.
|
||||
prompy_device_ips denotes device ip of NPU 0,1
|
||||
decode_device_ips denotes device ip of NPU 2,3
|
||||
The device ips of all NPUs in current server can be found through
|
||||
examples/disaggregated_prefill/find_device_ips.py
|
||||
"""
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
from multiprocessing import Event, Process
|
||||
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["1.2.3.1", "1.2.3.2"],
|
||||
"decode_device_ips": ["1.2.3.9", "1.2.3.10"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
}
|
||||
|
||||
|
||||
def clean_up():
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import (
|
||||
destroy_distributed_environment, destroy_model_parallel)
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
def run_prefill(prefill_done, process_close):
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you today?", "Hi, what is your name?",
|
||||
"Tell me a very long story.", "what is your favourite book?"
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||
|
||||
ktc = KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"AscendSimpleConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
|
||||
)
|
||||
global kv_connector_extra_config
|
||||
ktc.kv_connector_extra_config = kv_connector_extra_config
|
||||
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
tensor_parallel_size=2)
|
||||
|
||||
llm.generate(prompts, sampling_params)
|
||||
print("Prefill node is finished.")
|
||||
prefill_done.set()
|
||||
|
||||
# To keep the prefill node running in case the decode node is not done;
|
||||
# otherwise, the script might exit prematurely, causing incomplete decoding.
|
||||
try:
|
||||
while not process_close.is_set():
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("Script stopped by user.")
|
||||
finally:
|
||||
print("Cleanup prefill resources")
|
||||
del llm
|
||||
clean_up()
|
||||
|
||||
|
||||
def run_decode(prefill_done):
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you today?",
|
||||
"Hi, what is your name?",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95)
|
||||
|
||||
ktc = KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"AscendSimpleConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
|
||||
)
|
||||
global kv_connector_extra_config
|
||||
ktc.kv_connector_extra_config = kv_connector_extra_config
|
||||
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
tensor_parallel_size=2)
|
||||
|
||||
# Wait for the producer to start the consumer
|
||||
print("Waiting for prefill node to finish...")
|
||||
prefill_done.wait()
|
||||
|
||||
# At this point when the prefill_done is set, the kv-cache should have been
|
||||
# transferred to this decode node, so we can start decoding.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
del llm
|
||||
clean_up()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.get_context('spawn')
|
||||
|
||||
prefill_done = Event()
|
||||
process_close = Event()
|
||||
prefill_process = Process(target=run_prefill,
|
||||
args=(
|
||||
prefill_done,
|
||||
process_close,
|
||||
))
|
||||
decode_process = Process(target=run_decode, args=(prefill_done, ))
|
||||
|
||||
# Start prefill node
|
||||
prefill_process.start()
|
||||
|
||||
# Start decode node
|
||||
decode_process.start()
|
||||
|
||||
# Terminate the prefill node when decode is finished
|
||||
decode_process.join()
|
||||
|
||||
# Terminate prefill process
|
||||
process_close.set()
|
||||
prefill_process.join()
|
||||
prefill_process.terminate()
|
||||
print("All process done!")
|
||||
@@ -1,466 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import msgpack # type: ignore
|
||||
import zmq
|
||||
from quart import Quart, make_response, request
|
||||
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
DP_PROXY_HTTP_PORT = 10004
|
||||
DP_PROXY_ZMQ_REG_PORT = 30006
|
||||
DP_PROXY_ZMQ_NOTIFY_PORT = 30005
|
||||
|
||||
PD_PROXY_ADDRESS = "127.0.0.1:30002"
|
||||
|
||||
MY_HTTP_ADDRESS = f"127.0.0.1:{DP_PROXY_HTTP_PORT}"
|
||||
MY_ZMQ_ADDRESS_PLACEHOLDER = f"127.0.0.1:{DP_PROXY_ZMQ_REG_PORT}"
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TIME_INTERVAL_FOR_IDLE_RUN = 5e-4
|
||||
DP_SIZE = 2
|
||||
|
||||
dp_instances: dict[str, bool] = {}
|
||||
dp_cv = threading.Condition()
|
||||
round_robin_index = 0
|
||||
_idle_send_loop = None
|
||||
|
||||
|
||||
def make_idle_request():
|
||||
# Same as before
|
||||
data = {
|
||||
"prompt": "hi",
|
||||
"max_tokens": 1,
|
||||
"temperature": 0,
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
def random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
async def send_idle_token_to_client(schedule_dict):
|
||||
for key, value in schedule_dict.items():
|
||||
if value:
|
||||
continue
|
||||
request_received_id = random_uuid()
|
||||
idle_request_data = make_idle_request()
|
||||
forward_request_id = f"dp_idle_{key}_{request_received_id}"
|
||||
target_url = f'http://{key}/v1/completions'
|
||||
logger.debug(
|
||||
f"DP Decode Proxy: Sending idle token to D node {key} at {target_url}"
|
||||
)
|
||||
generator = forward_request_internal(target_url, idle_request_data,
|
||||
forward_request_id)
|
||||
try:
|
||||
async for response in generator:
|
||||
logger.debug(
|
||||
f"DP Decode Proxy: Idle Request {request_received_id}: response from {key}, got response: {response}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"DP Decode Proxy: Error sending idle token to {key}: {e}")
|
||||
|
||||
|
||||
def metadata_collect_trigger(poller, router_socket):
|
||||
global dp_instances
|
||||
global dp_cv
|
||||
global _idle_send_loop
|
||||
with dp_cv:
|
||||
dp_cv.wait()
|
||||
while True:
|
||||
try:
|
||||
schedule_dict = copy.deepcopy(dp_instances)
|
||||
for key in schedule_dict.keys():
|
||||
schedule_dict[key] = False
|
||||
first_start = False
|
||||
start_time = None
|
||||
while not all(schedule_dict.values()):
|
||||
if start_time is not None:
|
||||
time_interval = time.time() - start_time
|
||||
logger.debug("check time interval: ", time_interval)
|
||||
if time_interval > TIME_INTERVAL_FOR_IDLE_RUN:
|
||||
logger.debug(
|
||||
"exceeds max time interval send idle token to client"
|
||||
)
|
||||
# Send idle token to client in case of single dp rank run solo and block on the CCL part
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
send_idle_token_to_client(schedule_dict),
|
||||
_idle_send_loop) # type: ignore
|
||||
# Note: Reset start time prevent consistently send idle token to client
|
||||
# We only reset start time here, for some of the client may loss the idle token send from this proxy
|
||||
# and we only exit this while loop when we make sure all the client are exactly start inference in this
|
||||
# step
|
||||
start_time = time.time()
|
||||
socks = dict(poller.poll(timeout=500)) # timeout in 500ms
|
||||
if socks:
|
||||
logger.debug("receive socks from monitor threads: ", socks)
|
||||
if router_socket in socks:
|
||||
messages = router_socket.recv_multipart()
|
||||
try:
|
||||
# {"info": "notify_step", "http_address": ""}
|
||||
for message in messages:
|
||||
data = msgpack.loads(message)
|
||||
http_addr = None
|
||||
logger.debug(f"receive message {data}")
|
||||
if data.get("info") == "notify_step":
|
||||
http_addr = data.get("http_address")
|
||||
if http_addr in schedule_dict.keys():
|
||||
schedule_dict[http_addr] = True
|
||||
logger.debug("set first time")
|
||||
if not first_start:
|
||||
logger.debug("record start time")
|
||||
first_start = True
|
||||
start_time = time.time()
|
||||
else:
|
||||
logger.warning("Unrecognize http address")
|
||||
else:
|
||||
logger.warning(
|
||||
"Got unrecognize info type! We only accept notify step info yet"
|
||||
)
|
||||
except (msgpack.UnpackException, TypeError, KeyError) as e:
|
||||
logger.error(
|
||||
f"Error processing message from {http_addr}: {e}. Message: {data}"
|
||||
)
|
||||
except zmq.ZMQError as e: # type: ignore
|
||||
logger.error(f"ZMQ Error in monitor thread: {e}")
|
||||
if e.errno == zmq.ETERM: # type: ignore
|
||||
logger.error(
|
||||
"Monitor thread terminating due to context termination.")
|
||||
break
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in monitor thread: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _listen_for_d_register(poller, router_socket):
|
||||
global dp_instances
|
||||
global dp_cv
|
||||
global DP_SIZE
|
||||
logger.info(
|
||||
f"DP Decode Proxy: D Node ZMQ Listener started on ROUTER port {DP_PROXY_ZMQ_REG_PORT}"
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
socks = dict(poller.poll(timeout=1000))
|
||||
if router_socket in socks:
|
||||
remote_id, message = router_socket.recv_multipart()
|
||||
try:
|
||||
data = msgpack.loads(message)
|
||||
if data.get("type") == "DP":
|
||||
http_addr = data.get("http_address")
|
||||
zmq_addr = data.get("zmq_address")
|
||||
if http_addr:
|
||||
with dp_cv:
|
||||
if http_addr not in dp_instances:
|
||||
logger.info(
|
||||
f"DP Decode Proxy: Registering D Node instance: http={http_addr}, zmq={zmq_addr}"
|
||||
)
|
||||
dp_instances[http_addr] = True
|
||||
if len(dp_instances) >= DP_SIZE:
|
||||
logger.info(
|
||||
f"DP Decode Proxy: Reached expected D Node count ({DP_SIZE}). Notifying metadata collector."
|
||||
)
|
||||
dp_cv.notify_all()
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
logger.warning(
|
||||
f"DP Decode Proxy: Received D Node registration from {remote_id.decode()} without http_address. Data: {data}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"DP Decode Proxy: Received message with unexpected type from {remote_id.decode()}. Type: {data.get('type')}, Data: {data}"
|
||||
)
|
||||
|
||||
except (msgpack.UnpackException, TypeError, KeyError) as e:
|
||||
logger.error(
|
||||
f"DP Decode Proxy: Error processing D Node registration from {remote_id.decode()}: {e}. Message: {message}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"DP Decode Proxy: Unexpected error processing D Node registration from {remote_id.decode()}: {e}"
|
||||
)
|
||||
|
||||
except zmq.ZMQError as e: # type: ignore
|
||||
logger.error(
|
||||
f"DP Decode Proxy: ZMQ Error in D Node listener thread: {e}")
|
||||
if e.errno == zmq.ETERM: # type: ignore
|
||||
logger.info(
|
||||
"DP Decode Proxy: D Node Listener thread terminating.")
|
||||
break
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"DP Decode Proxy: Unexpected error in D Node listener thread: {e}"
|
||||
)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _register_to_pd_proxy(pd_proxy_zmq_addr, my_http_addr, my_zmq_addr):
|
||||
context = None
|
||||
sock = None
|
||||
while True:
|
||||
try:
|
||||
if context is None:
|
||||
context = zmq.Context() # type: ignore
|
||||
if sock is None:
|
||||
sock = context.socket(zmq.DEALER) # type: ignore
|
||||
identity = f"dp_proxy_{my_http_addr}".encode('utf-8')
|
||||
sock.setsockopt(zmq.IDENTITY, identity) # type: ignore
|
||||
sock.setsockopt(zmq.LINGER, 0) # type: ignore
|
||||
logger.info(
|
||||
f"DP Decode Proxy: Attempting to connect to PD Proxy at {pd_proxy_zmq_addr}..."
|
||||
)
|
||||
sock.connect(f"tcp://{pd_proxy_zmq_addr}")
|
||||
logger.info(
|
||||
f"DP Decode Proxy: Connected to PD Proxy at {pd_proxy_zmq_addr}."
|
||||
)
|
||||
|
||||
data = {
|
||||
"type": "D",
|
||||
"http_address": my_http_addr,
|
||||
"zmq_address": my_zmq_addr
|
||||
}
|
||||
logger.debug(
|
||||
f"DP Decode Proxy: Sending registration/heartbeat to PD Proxy: {data}"
|
||||
)
|
||||
sock.send(msgpack.dumps(data))
|
||||
time.sleep(5)
|
||||
|
||||
except zmq.ZMQError as e: # type: ignore
|
||||
logger.error(
|
||||
f"DP Decode Proxy: ZMQ Error connecting/sending to PD Proxy ({pd_proxy_zmq_addr}): {e}"
|
||||
)
|
||||
if sock:
|
||||
sock.close()
|
||||
sock = None
|
||||
time.sleep(10)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"DP Decode Proxy: Unexpected error in PD Proxy registration thread: {e}"
|
||||
)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
if sock:
|
||||
sock.close()
|
||||
sock = None
|
||||
time.sleep(10)
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
def start_zmq_thread(hostname, port, socket_type, target_func, thread_name):
|
||||
"""Generic ZMQ thread starter for ROUTER or PULL."""
|
||||
if not hostname:
|
||||
hostname = "0.0.0.0"
|
||||
context = zmq.Context.instance() # type: ignore
|
||||
socket = context.socket(socket_type)
|
||||
socket.setsockopt(zmq.LINGER, 0) # type: ignore
|
||||
try:
|
||||
socket.bind(f"tcp://{hostname}:{port}")
|
||||
except zmq.ZMQError as e: # type: ignore
|
||||
logger.error(
|
||||
f"DP Decode Proxy: Error binding ZMQ {socket_type} socket to tcp://{hostname}:{port}: {e}"
|
||||
)
|
||||
socket.close()
|
||||
raise
|
||||
|
||||
poller = zmq.Poller() # type: ignore
|
||||
poller.register(socket, zmq.POLLIN) # type: ignore
|
||||
|
||||
thread = threading.Thread(target=target_func,
|
||||
args=(poller, socket),
|
||||
daemon=True,
|
||||
name=thread_name)
|
||||
thread.start()
|
||||
return thread, socket
|
||||
|
||||
|
||||
def start_thread_with_event_loop():
|
||||
global _idle_send_loop
|
||||
asyncio.set_event_loop(_idle_send_loop)
|
||||
_idle_send_loop.run_forever() # type: ignore
|
||||
|
||||
|
||||
async def forward_request_internal(url, data, request_id):
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
headers = {
|
||||
"Authorization":
|
||||
f"Bearer {os.environ.get('OPENAI_API_KEY', '')}",
|
||||
"X-Request-Id": request_id,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
async with session.post(url=url, json=data,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content.iter_chunked(
|
||||
1024):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
error_content = await response.read()
|
||||
logger.warning(
|
||||
f"DP Decode Proxy: Error from D node {url} (status {response.status}): {error_content.decode(errors='ignore')}"
|
||||
)
|
||||
yield error_content
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.warning(
|
||||
f"DP Decode Proxy: Error forwarding request {request_id} to D node {url}: {e}"
|
||||
)
|
||||
error_msg = f"Failed to connect or communicate with D node at {url}: {e}".encode(
|
||||
'utf-8')
|
||||
yield error_msg
|
||||
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
app = Quart(__name__)
|
||||
|
||||
|
||||
@app.route('/v1/completions', methods=['POST'])
|
||||
async def handle_request():
|
||||
global dp_instances
|
||||
global dp_cv
|
||||
global round_robin_index
|
||||
|
||||
request_received_id = request.headers.get("X-Request-Id")
|
||||
if not request_received_id:
|
||||
fallback_id = f"dp_fallback_{random_uuid()}"
|
||||
logger.warning(
|
||||
f"DP Decode Proxy: Received request without X-Request-Id header. Using fallback ID: {fallback_id}"
|
||||
)
|
||||
request_received_id = fallback_id
|
||||
else:
|
||||
logger.info(
|
||||
f"DP Decode Proxy: Received request from PD Proxy, using propagated ID: {request_received_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
original_request_data = await request.get_json()
|
||||
if not original_request_data:
|
||||
return await make_response("Request body must be valid JSON.", 400)
|
||||
|
||||
target_addr = None
|
||||
with dp_cv:
|
||||
if not dp_instances:
|
||||
logger.warning(
|
||||
f"DP Decode Proxy: Request {request_received_id}: No D Node instances available/registered."
|
||||
)
|
||||
return await make_response("No Decode instances available.",
|
||||
503)
|
||||
|
||||
dp_addresses = list(dp_instances.keys())
|
||||
if not dp_addresses:
|
||||
logger.error(
|
||||
f"DP Decode Proxy: Request {request_received_id}: Internal error - dp_instances populated but list is empty."
|
||||
)
|
||||
return await make_response("Internal Server Error", 500)
|
||||
|
||||
current_selection_index = round_robin_index % len(dp_addresses)
|
||||
target_addr = dp_addresses[current_selection_index]
|
||||
round_robin_index += 1
|
||||
|
||||
logger.info(
|
||||
f"DP Decode Proxy: Request {request_received_id}: Routing Decode to D Node {target_addr} (Index {current_selection_index})"
|
||||
)
|
||||
|
||||
target_url = f'http://{target_addr}/v1/completions'
|
||||
|
||||
generator = forward_request_internal(target_url, original_request_data,
|
||||
request_received_id)
|
||||
|
||||
response = await make_response(generator)
|
||||
response.timeout = None
|
||||
|
||||
if original_request_data.get("stream", False):
|
||||
response.headers['Content-Type'] = 'text/event-stream'
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
else:
|
||||
response.headers['Content-Type'] = 'application/json'
|
||||
|
||||
logger.debug(
|
||||
f"DP Decode Proxy: Request {request_received_id}: Streaming response from D node {target_addr}"
|
||||
)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"DP Decode Proxy: Error handling request {request_received_id}: {e}"
|
||||
)
|
||||
return await make_response("Internal Server Error", 500)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
d_listener_thread, d_reg_socket = start_zmq_thread(
|
||||
"0.0.0.0",
|
||||
DP_PROXY_ZMQ_REG_PORT,
|
||||
zmq.ROUTER, # type: ignore
|
||||
_listen_for_d_register, # type: ignore
|
||||
"DP_DNodeListenerThread")
|
||||
|
||||
metadata_thread, notify_socket = start_zmq_thread(
|
||||
"0.0.0.0",
|
||||
DP_PROXY_ZMQ_NOTIFY_PORT,
|
||||
zmq.PULL, # type: ignore
|
||||
metadata_collect_trigger,
|
||||
"DP_MetadataMonitorThread")
|
||||
|
||||
_idle_send_loop = asyncio.new_event_loop()
|
||||
idle_loop_thread = threading.Thread(target=start_thread_with_event_loop,
|
||||
daemon=True,
|
||||
name="DP_IdleSendLoopThread")
|
||||
idle_loop_thread.start()
|
||||
|
||||
pd_register_thread = threading.Thread(target=_register_to_pd_proxy,
|
||||
args=(PD_PROXY_ADDRESS,
|
||||
MY_HTTP_ADDRESS,
|
||||
MY_ZMQ_ADDRESS_PLACEHOLDER),
|
||||
daemon=True,
|
||||
name="DP_PDRegisterThread")
|
||||
pd_register_thread.start()
|
||||
|
||||
logger.info(
|
||||
f"DP Decode Proxy: Starting Quart web server on http://0.0.0.0:{DP_PROXY_HTTP_PORT}"
|
||||
)
|
||||
zmq_context = zmq.Context.instance() # type: ignore
|
||||
try:
|
||||
app.run(host='0.0.0.0', port=DP_PROXY_HTTP_PORT)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("DP Decode Proxy: KeyboardInterrupt received, stopping...")
|
||||
except Exception as e:
|
||||
logger.error(f"DP Decode Proxy: Failed to run Quart server: {e}")
|
||||
finally:
|
||||
logger.info("DP Decode Proxy: Shutting down...")
|
||||
if _idle_send_loop and _idle_send_loop.is_running():
|
||||
logger.info("DP Decode Proxy: Stopping idle send loop...")
|
||||
_idle_send_loop.call_soon_threadsafe(_idle_send_loop.stop)
|
||||
|
||||
if d_reg_socket:
|
||||
d_reg_socket.close()
|
||||
if notify_socket:
|
||||
notify_socket.close()
|
||||
if zmq_context:
|
||||
zmq_context.term()
|
||||
|
||||
logger.info("DP Decode Proxy: Shutdown complete.")
|
||||
@@ -1,69 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
|
||||
#
|
||||
"""
|
||||
This file provides a function to obtain ips of all NPU Devices in current machine.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
|
||||
# Get all device ips using hccn_tool
|
||||
HCCN_TOOL_PATH = envs.HCCN_PATH
|
||||
|
||||
|
||||
def get_device_ips():
|
||||
npu_info = subprocess.run(['npu-smi', 'info', '-m'],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
|
||||
raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")
|
||||
|
||||
# Extract NPU IDs for all Ascend devices (excluding Mcu rows)
|
||||
device_ids = []
|
||||
for line in npu_info.stdout.strip().split('\n'):
|
||||
match = re.match(r'^\s*(\d+)\s+\d+\s+\d+\s+Ascend', line)
|
||||
if match:
|
||||
device_ids.append(int(match.group(1)))
|
||||
|
||||
if not device_ids:
|
||||
raise RuntimeError(
|
||||
"Cannot parse any valid device ID from npu-smi output.")
|
||||
|
||||
device_ip_list = []
|
||||
for device_id in device_ids:
|
||||
cmd = [HCCN_TOOL_PATH, '-i', str(device_id), '-ip', '-g']
|
||||
device_ip_info = subprocess.run(cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
ip_match = re.search(r'ipaddr:(.*)', device_ip_info.stdout)
|
||||
if not ip_match:
|
||||
raise RuntimeError(
|
||||
f"Cannot parse IP from hccn_tool for device {device_id}")
|
||||
device_ip = ip_match.group(1).strip()
|
||||
device_ip_list.append(device_ip)
|
||||
|
||||
return device_ip_list
|
||||
|
||||
|
||||
print(get_device_ips())
|
||||
@@ -1,196 +0,0 @@
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import msgpack # type: ignore
|
||||
import zmq
|
||||
from quart import Quart, make_response, request
|
||||
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
prefill_instances: dict[str, str] = {} # http_address: zmq_address
|
||||
decode_instances: dict[str, str] = {} # http_address: zmq_address
|
||||
|
||||
prefill_cv = threading.Condition()
|
||||
decode_cv = threading.Condition()
|
||||
|
||||
|
||||
def _listen_for_register(poller, router_socket):
|
||||
while True:
|
||||
socks = dict(poller.poll())
|
||||
if router_socket in socks:
|
||||
remote_address, message = router_socket.recv_multipart()
|
||||
# data: {"type": "P", "http_address": "ip:port",
|
||||
# "zmq_address": "ip:port"}
|
||||
data = msgpack.loads(message)
|
||||
if data["type"] == "P":
|
||||
global prefill_instances
|
||||
global prefill_cv
|
||||
with prefill_cv:
|
||||
prefill_instances[
|
||||
data["http_address"]] = data["zmq_address"]
|
||||
print(
|
||||
"Get a prefill register with http_addr %s and zmq_addr %s",
|
||||
data["http_address"],
|
||||
data["zmq_address"],
|
||||
)
|
||||
elif data["type"] == "D":
|
||||
global decode_instances
|
||||
global decode_cv
|
||||
with decode_cv:
|
||||
decode_instances[
|
||||
data["http_address"]] = data["zmq_address"]
|
||||
print(
|
||||
"Get a decode register with http_addr %s and zmq_addr %s",
|
||||
data["http_address"],
|
||||
data["zmq_address"],
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Unexpected, Received message from %s, data: %s",
|
||||
remote_address,
|
||||
data,
|
||||
)
|
||||
|
||||
|
||||
def start_service_discovery(hostname, port):
|
||||
if not hostname:
|
||||
hostname = socket.gethostname()
|
||||
if port == 0:
|
||||
raise ValueError("Port cannot be 0")
|
||||
|
||||
context = zmq.Context() # type: ignore
|
||||
router_socket = context.socket(zmq.ROUTER) # type: ignore
|
||||
router_socket.bind(f"tcp://{hostname}:{port}")
|
||||
|
||||
poller = zmq.Poller() # type: ignore
|
||||
poller.register(router_socket, zmq.POLLIN) # type: ignore
|
||||
|
||||
_listener_thread = threading.Thread(target=_listen_for_register,
|
||||
args=[poller, router_socket],
|
||||
daemon=True)
|
||||
_listener_thread.start()
|
||||
return _listener_thread
|
||||
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
app = Quart(__name__)
|
||||
|
||||
|
||||
def random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
async def forward_request(url, data, request_id):
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id,
|
||||
}
|
||||
async with session.post(url=url, json=data,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
||||
yield chunk_bytes
|
||||
|
||||
|
||||
@app.route("/v1/completions", methods=["POST"])
|
||||
async def handle_request():
|
||||
try:
|
||||
original_request_data = await request.get_json()
|
||||
|
||||
prefill_request = original_request_data.copy()
|
||||
# change max_tokens = 1 to let it only do prefill
|
||||
prefill_request["max_tokens"] = 1
|
||||
|
||||
global prefill_instances
|
||||
global prefill_cv
|
||||
with prefill_cv:
|
||||
if len(prefill_instances) > 1:
|
||||
print(
|
||||
"Found more than 1 Prefill instances. Currently we only support 1P1D, so only"
|
||||
f"the first Prefill instance({list(prefill_instances.keys())[0]}) will be used!"
|
||||
)
|
||||
if len(prefill_instances) == 0:
|
||||
res_str = (
|
||||
"No Prefill instances has been registered to proxy. Please confirm that you have successfully"
|
||||
" and correctly started a Prefill vLLM instance.")
|
||||
print(res_str)
|
||||
response = await make_response(res_str)
|
||||
return response
|
||||
# prefill_addr, prefill_zmq_addr = random.choice(
|
||||
# list(prefill_instances.items()))
|
||||
prefill_addr, prefill_zmq_addr = list(prefill_instances.items())[0]
|
||||
print(
|
||||
"handle_request, prefill_addr: %s, zmq_addr: %s",
|
||||
prefill_addr,
|
||||
prefill_zmq_addr,
|
||||
)
|
||||
|
||||
global decode_instances
|
||||
global decode_cv
|
||||
with decode_cv:
|
||||
if len(decode_instances) > 1:
|
||||
print(
|
||||
"Found more than 1 Decode instances. Currently we only support 1P1D, so only"
|
||||
f"the first Decode instance({list(decode_instances.keys())[0]}) will be used!"
|
||||
)
|
||||
if len(decode_instances) == 0:
|
||||
res_str = (
|
||||
"No Decode instances has been registered to proxy. Please confirm that you have successfully"
|
||||
" and correctly started a Decode vLLM instance.")
|
||||
print(res_str)
|
||||
response = await make_response(res_str)
|
||||
return response
|
||||
# decode_addr, decode_zmq_addr = random.choice(
|
||||
# list(decode_instances.items()))
|
||||
decode_addr, decode_zmq_addr = list(decode_instances.items())[0]
|
||||
print(
|
||||
"handle_request, decode_addr: %s, zmq_addr: %s",
|
||||
decode_addr,
|
||||
decode_zmq_addr,
|
||||
)
|
||||
|
||||
request_id = f"___prefill_addr_{prefill_addr}___decode_addr_{decode_addr}_{random_uuid()}"
|
||||
|
||||
# finish prefill
|
||||
async for _ in forward_request(f"http://{prefill_addr}/v1/completions",
|
||||
prefill_request, request_id):
|
||||
continue
|
||||
|
||||
# return decode
|
||||
generator = forward_request(
|
||||
f"http://{decode_addr}/v1/completions",
|
||||
original_request_data,
|
||||
request_id,
|
||||
)
|
||||
response = await make_response(generator)
|
||||
response.timeout = None
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
exc_info = sys.exc_info()
|
||||
print("Error occurred in disagg prefill proxy server")
|
||||
print(e)
|
||||
print("".join(traceback.format_exception(*exc_info)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(
|
||||
description="args of disaggregated-prefill proxy")
|
||||
parser.add_argument("--http-port", type=int, default=10001)
|
||||
parser.add_argument("--register-port", type=int, default=10002)
|
||||
args = parser.parse_args()
|
||||
|
||||
t = start_service_discovery("0.0.0.0", args.register_port)
|
||||
app.run(host="0.0.0.0", port=args.http_port)
|
||||
t.join()
|
||||
@@ -1,37 +0,0 @@
|
||||
export HCCL_IF_IP=2.0.0.0
|
||||
export GLOO_SOCKET_IFNAME="enp189s0f0"
|
||||
export TP_SOCKET_IFNAME="enp189s0f0"
|
||||
export HCCL_SOCKET_IFNAME="enp189s0f0"
|
||||
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
|
||||
export VLLM_USE_V1=0
|
||||
|
||||
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
|
||||
|
||||
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
|
||||
--host 0.0.0.0 \
|
||||
--port 20002 \
|
||||
--tensor-parallel-size 8 \
|
||||
--seed 1024 \
|
||||
--served-model-name deepseek \
|
||||
--max-model-len 2000 \
|
||||
--max-num-batched-tokens 2000 \
|
||||
--trust-remote-code \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector": "AscendSimpleConnector",
|
||||
"kv_buffer_device": "npu",
|
||||
"kv_role": "kv_consumer",
|
||||
"kv_parallel_size": 8,
|
||||
"kv_port":"21001",
|
||||
"kv_connector_extra_config":
|
||||
{"prompt_device_ips": ["1.2.3.1", "1.2.3.2", "1.2.3.3", "1.2.3.4", "1.2.3.5", "1.2.3.6", "1.2.3.7", "1.2.3.8"],
|
||||
"decode_device_ips": ["1.2.3.9", "1.2.3.10", "1.2.3.11", "1.2.3.12", "1.2.3.13", "1.2.3.14", "1.2.3.15", "1.2.3.16"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"proxy_ip":"3.0.0.0",
|
||||
"proxy_port":"30001",
|
||||
"http_port": 10002}
|
||||
}'
|
||||
@@ -1,37 +0,0 @@
|
||||
export HCCL_IF_IP=1.0.0.0
|
||||
export GLOO_SOCKET_IFNAME="enp189s0f0"
|
||||
export TP_SOCKET_IFNAME="enp189s0f0"
|
||||
export HCCL_SOCKET_IFNAME="enp189s0f0"
|
||||
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
|
||||
export VLLM_USE_V1=0
|
||||
|
||||
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
||||
|
||||
|
||||
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
|
||||
--host 0.0.0.0 \
|
||||
--port 10002 \
|
||||
--tensor-parallel-size 8 \
|
||||
--seed 1024 \
|
||||
--served-model-name deepseek \
|
||||
--max-model-len 2000 \
|
||||
--max-num-batched-tokens 2000 \
|
||||
--trust-remote-code \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector": "AscendSimpleConnector",
|
||||
"kv_buffer_device": "npu",
|
||||
"kv_role": "kv_producer",
|
||||
"kv_parallel_size": 8,
|
||||
"kv_port":"11001",
|
||||
"kv_connector_extra_config":
|
||||
{"prompt_device_ips": ["1.2.3.1", "1.2.3.2", "1.2.3.3", "1.2.3.4", "1.2.3.5", "1.2.3.6", "1.2.3.7", "1.2.3.8"],
|
||||
"decode_device_ips": ["1.2.3.9", "1.2.3.10", "1.2.3.11", "1.2.3.12", "1.2.3.13", "1.2.3.14", "1.2.3.15", "1.2.3.16"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"proxy_ip":"3.0.0.0",
|
||||
"proxy_port":"30001",
|
||||
"http_port": 10002}
|
||||
}'
|
||||
@@ -1,71 +0,0 @@
|
||||
import zlib
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
|
||||
int32_hash)
|
||||
|
||||
|
||||
class MockSimplePipe:
|
||||
|
||||
def __init__(self):
|
||||
self.cluster_id = 0
|
||||
self.send_tensor = MagicMock()
|
||||
self.recv_tensor = MagicMock()
|
||||
self.deallocate_buffer = MagicMock()
|
||||
|
||||
|
||||
class TestSimpleBuffer(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.pipe = MockSimplePipe()
|
||||
self.buffer = SimpleBuffer(self.pipe)
|
||||
|
||||
def test_int32_hash(self):
|
||||
self.assertEqual(int32_hash("test"), zlib.adler32(b"test"))
|
||||
|
||||
def test_insert(self):
|
||||
input_tokens = torch.tensor([1, 2, 3])
|
||||
roi = torch.tensor([1, 0, 1])
|
||||
key = torch.randn(2, 3, 4, 5)
|
||||
value = torch.randn(2, 3, 4, 5)
|
||||
hidden = torch.randn(3, 6)
|
||||
|
||||
self.buffer.num_layers = 2
|
||||
self.buffer.num_heads = 4
|
||||
self.buffer.head_size = 5
|
||||
self.buffer.hidden_size = 6
|
||||
self.buffer.dtype = torch.float32
|
||||
|
||||
self.buffer.insert(input_tokens, roi, key, value, hidden, "req1")
|
||||
|
||||
self.pipe.send_tensor.assert_called()
|
||||
|
||||
def test_drop_select(self):
|
||||
input_tokens = torch.tensor([1, 2, 3])
|
||||
roi = None
|
||||
|
||||
self.buffer.num_layers = 2
|
||||
self.buffer.num_heads = 4
|
||||
self.buffer.head_size = 5
|
||||
self.buffer.hidden_size = 6
|
||||
self.buffer.dtype = torch.float32
|
||||
|
||||
self.pipe.recv_tensor.side_effect = [
|
||||
(MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
|
||||
(MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
|
||||
(MagicMock(), torch.randn(1, 3, 6))
|
||||
]
|
||||
|
||||
result = self.buffer.drop_select(input_tokens, roi, "req1")
|
||||
self.assertEqual(len(result), 4)
|
||||
self.assertIsInstance(result[0], torch.Tensor)
|
||||
self.assertIsInstance(result[1], torch.Tensor)
|
||||
self.assertIsInstance(result[2], torch.Tensor)
|
||||
self.assertIsNone(result[3])
|
||||
self.assertEqual(result[0].shape, (2, 3, 4, 5))
|
||||
|
||||
def test_close(self):
|
||||
self.buffer.close()
|
||||
@@ -1,146 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
|
||||
from vllm_ascend.distributed.kv_transfer.simple_connector import \
|
||||
SimpleConnector
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
|
||||
|
||||
class TestSimpleConnector(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_pipe = MagicMock(spec=SimplePipe)
|
||||
self.mock_buffer = MagicMock(spec=SimpleBuffer)
|
||||
|
||||
patcher = patch(
|
||||
'vllm_ascend.distributed.kv_transfer.simple_buffer.SimpleBuffer')
|
||||
self.addCleanup(patcher.stop)
|
||||
self.MockSimpleBuffer = patcher.start()
|
||||
self.MockSimpleBuffer.return_value = self.mock_buffer
|
||||
|
||||
def _create_mock_config(self, kv_role):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["127.0.0.1"],
|
||||
"decode_device_ips": ["127.0.0.1"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
mock_config.kv_port = 5500
|
||||
self.mock_config = MagicMock(spec=VllmConfig)
|
||||
self.mock_config.kv_transfer_config.is_kv_producer = True
|
||||
self.mock_config.model_config.hf_config.hidden_size = 128
|
||||
self.mock_config.model_config.hf_config.num_attention_heads = 8
|
||||
self.mock_config.model_config.hf_config.num_key_value_heads = 8
|
||||
self.mock_config.model_config.hf_config.qk_rope_head_dim = 16
|
||||
self.mock_config.model_config.hf_config.kv_lora_rank = 16
|
||||
self.mock_config.model_config.is_deepseek_mla = True
|
||||
# 模拟 parallel_config
|
||||
self.mock_config.parallel_config = MagicMock()
|
||||
self.mock_config.parallel_config.tensor_parallel_size = 1
|
||||
self.mock_config.parallel_config.get_num_layers.return_value = 4
|
||||
|
||||
if kv_role == "kv_producer":
|
||||
self.mock_config.kv_transfer_config.kv_role = "kv_producer"
|
||||
else:
|
||||
self.mock_config.kv_transfer_config.kv_role = "kv_consumer"
|
||||
return mock_config
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_select_init(self, mock_pipe, mock_buffer, MockLLMDataDist):
|
||||
"""Test select method when buffer retrieval succeeds."""
|
||||
connector = SimpleConnector(
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
config=self._create_mock_config("kv_producer"))
|
||||
assert connector.producer_data_pipe is not None
|
||||
assert connector.producer_buffer is not None
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_select_select(self, mock_pipe, mock_buffer, MockLLMDataDist):
|
||||
|
||||
connector = SimpleConnector(
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
config=self._create_mock_config("kv_consumer"))
|
||||
connector.consumer_data_pipe = mock_pipe
|
||||
connector.consumer_buffer = mock_buffer
|
||||
assert connector.consumer_data_pipe is not None
|
||||
assert connector.consumer_buffer is not None
|
||||
input_tokens = torch.tensor([1, 2, 3])
|
||||
roi = torch.tensor([True, True, True])
|
||||
req_id = "test_req"
|
||||
connector.select(input_tokens, roi, req_id)
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_insert(self, mock_pipe, mock_buffer, MockLLMDataDist):
|
||||
"""Test insert operation"""
|
||||
connector = SimpleConnector(
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
config=self._create_mock_config("kv_producer"))
|
||||
|
||||
connector.producer_buffer = mock_buffer
|
||||
|
||||
input_tokens = torch.randint(0, 1000, (5, ))
|
||||
roi = torch.ones_like(input_tokens, dtype=torch.bool)
|
||||
keys = torch.randn(3, 5, 1, 96)
|
||||
values = torch.randn(3, 5, 1, 96)
|
||||
hidden = torch.randn(5, 768)
|
||||
req_id = "test_req"
|
||||
|
||||
connector.insert(input_tokens, roi, keys, values, hidden, req_id)
|
||||
|
||||
mock_buffer.insert.assert_called_once_with(input_tokens, roi, keys,
|
||||
values, hidden, req_id)
|
||||
|
||||
@patch.object(SimpleConnector, 'insert')
|
||||
@patch('torch.distributed.get_rank', return_value=0)
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_send_kv_caches_and_hidden_states(self, mock_pipe, mock_buffer,
|
||||
MockLLMDataDist, mock_insert,
|
||||
mock_rank):
|
||||
"""Test sending KV caches and hidden states"""
|
||||
connector = SimpleConnector(
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
config=self._create_mock_config("kv_producer"))
|
||||
|
||||
mock_model_executable = MagicMock()
|
||||
mock_model_executable.model.start_layer = 0
|
||||
mock_model_executable.model.end_layer = 3
|
||||
|
||||
mock_model_input = MagicMock(spec=ModelInputForGPUWithSamplingMetadata)
|
||||
mock_model_input.input_tokens = torch.randint(0, 1000, (10, ))
|
||||
mock_model_input.attn_metadata.seq_lens = [5, 5]
|
||||
mock_model_input.attn_metadata.slot_mapping = torch.randint(
|
||||
0, 100, (10, ))
|
||||
mock_model_input.attn_metadata.num_prefill_tokens = 10
|
||||
mock_model_input.request_ids_to_seq_ids = {"req1": [0], "req2": [1]}
|
||||
|
||||
kv_caches = [torch.randn(2, 100, 1, 96) for _ in range(3)]
|
||||
|
||||
hidden_states = torch.randn(10, 768)
|
||||
|
||||
connector.send_kv_caches_and_hidden_states(mock_model_executable,
|
||||
mock_model_input, kv_caches,
|
||||
hidden_states)
|
||||
@@ -1,145 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
|
||||
|
||||
class TestSimplePipe(TestBase):
|
||||
|
||||
@classmethod
|
||||
def _create_mock_config(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["127.0.0.1"],
|
||||
"decode_device_ips": ["127.0.0.1"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
mock_config.kv_port = 5500
|
||||
return mock_config
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_init_success(self, mock_thread, MockLLMDataDist):
|
||||
|
||||
mock_config = self._create_mock_config()
|
||||
|
||||
self.pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
|
||||
self.pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_prepare_data_dist(self, mock_thread, MockLLMDataDist):
|
||||
self.pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=self._create_mock_config(),
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
self.pipe.router_socket.close()
|
||||
|
||||
def test_init_with_invalid_kv_role(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "err_role"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["127.0.0.1"],
|
||||
"decode_device_ips": ["127.0.0.1"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
pipe.router_socket.close()
|
||||
|
||||
def test_init_with_missing_device_ips(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=0,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_create_register_thread_address_is_empty(self, MockThread,
|
||||
MockLLMDataDist):
|
||||
|
||||
mock_config = self._create_mock_config()
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
self.assertIsNotNone(pipe._register_thread)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_create_register_thread_address_is_not_empty(
|
||||
self, MockThread, MockLLMDataDist):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": [""],
|
||||
"decode_device_ips": [""],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
self.assertIsNotNone(pipe._register_thread)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_should_send_tensor_when_valid_input(self, MockSimplePipe,
|
||||
MockLLMDataDist):
|
||||
pipe = MockSimplePipe()
|
||||
tensor = torch.randn(3, 3)
|
||||
tensor_desc = MockLLMDataDist.CacheDesc(
|
||||
num_tensors=1,
|
||||
shape=(3, 3),
|
||||
data_type=MockLLMDataDist.DataType.DT_FLOAT,
|
||||
seq_len_dim_index=1)
|
||||
tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
|
||||
result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
|
||||
self.assertIsNotNone(result)
|
||||
@@ -18,14 +18,6 @@
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import \
|
||||
KVConnectorFactory
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
|
||||
"LLMDataDistConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"AscendSimpleConnector",
|
||||
"vllm_ascend.distributed.kv_transfer.simple_connector", "SimpleConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"LLMDataDistCMgrConnector",
|
||||
"vllm_ascend.distributed.llmdatadist_c_mgr_connector",
|
||||
|
||||
@@ -1,207 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import zlib
|
||||
from typing import List, Optional
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import torch
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import \
|
||||
KVLookupBufferBase
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
from vllm_ascend.distributed.kv_transfer.utils import TORCH_DTYPE_TO_NPU_DTYPE
|
||||
|
||||
|
||||
# Hash a string into a int32 value.
|
||||
def int32_hash(data):
|
||||
assert isinstance(data, str)
|
||||
data = data.encode("utf-8")
|
||||
return zlib.adler32(data)
|
||||
|
||||
|
||||
class SimpleBuffer(KVLookupBufferBase):
|
||||
|
||||
def __init__(self, data_pipe: SimplePipe):
|
||||
self.data_pipe = data_pipe
|
||||
# Consumer buffer need these information to construct receiving buffer.
|
||||
self.num_layers = None
|
||||
self.num_heads = None
|
||||
self.head_size = None
|
||||
self.dtype = None
|
||||
self.hidden_size = None
|
||||
self.key_buffer = None
|
||||
self.value_buffer = None
|
||||
self.hidden_buffer = None
|
||||
|
||||
def insert(
|
||||
self,
|
||||
input_tokens: torch.Tensor,
|
||||
roi: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
hidden: torch.Tensor,
|
||||
req_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
seq_len: num_tokens of current request.
|
||||
input_tokens: [seq_len]
|
||||
roi: [seq_len]
|
||||
key: [num_layers, seq_len, num_kv_heads, head_size]
|
||||
value: [num_layers, seq_len, num_kv_heads, head_size]
|
||||
hidden: [seq_len, hidden_size]
|
||||
"""
|
||||
orig_k_shape = key.shape
|
||||
num_layers = orig_k_shape[0]
|
||||
|
||||
# unsequeeze all tensors to make first dim to 1.
|
||||
# This is because D node can only pull one batch data from P.
|
||||
# So we make first dim to 1 here in order to pull full data.
|
||||
key = key.view(num_layers, -1).unsqueeze(0)
|
||||
value = value.view(num_layers, -1).unsqueeze(0)
|
||||
hidden = hidden.unsqueeze(0)
|
||||
|
||||
hidden_dtype = key.dtype
|
||||
# initialize LLMDatadist data structure
|
||||
key_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
key.shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
|
||||
seq_len_dim_index=1,
|
||||
)
|
||||
value_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
value.shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
|
||||
seq_len_dim_index=1,
|
||||
)
|
||||
hidden_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
hidden.shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[hidden_dtype],
|
||||
seq_len_dim_index=-1,
|
||||
)
|
||||
|
||||
req_id = int32_hash(req_id)
|
||||
key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 1)
|
||||
value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 2)
|
||||
hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 3)
|
||||
|
||||
# Currently we use hash value of request id as key, so no need to send input_tokens
|
||||
self.key_buffer = self.data_pipe.send_tensor(key, key_desc,
|
||||
key_cache_key)
|
||||
self.value_buffer = self.data_pipe.send_tensor(value, value_desc,
|
||||
value_cache_key)
|
||||
self.hidden_buffer = self.data_pipe.send_tensor(
|
||||
hidden, hidden_desc, hidden_cache_key)
|
||||
|
||||
def drop_select(
|
||||
self,
|
||||
input_tokens: torch.Tensor,
|
||||
roi: Optional[torch.Tensor],
|
||||
req_id: str,
|
||||
) -> List[Optional[torch.Tensor]]:
|
||||
"""Select and *drop* KV cache entries from the lookup buffer.
|
||||
|
||||
The functionality is similar to the following python statements
|
||||
```
|
||||
ret = buffer.pop(input_tokens, roi)
|
||||
return ret
|
||||
```
|
||||
|
||||
Args:
|
||||
input_tokens (torch.Tensor): token IDs.
|
||||
roi (torch.Tensor): A binary mask on top of the input tokens
|
||||
|
||||
Returns:
|
||||
A list of tensors including:
|
||||
key: [num_layers, num_tokens, num_heads, head_size]
|
||||
value: [num_layers, num_tokens, num_heads, head_size]
|
||||
hidden_or_intermediate_states: [num_tokens, hidden_size]
|
||||
roi: None (Currently we don't supported roi)
|
||||
"""
|
||||
orig_req_id = req_id
|
||||
req_id = int32_hash(req_id)
|
||||
num_tokens = input_tokens.shape[0]
|
||||
kv_shape = (
|
||||
1,
|
||||
self.num_layers,
|
||||
num_tokens * self.num_heads * self.head_size,
|
||||
)
|
||||
hidden_shape = (1, num_tokens, self.hidden_size)
|
||||
key_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
|
||||
seq_len_dim_index=-1,
|
||||
)
|
||||
value_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
|
||||
seq_len_dim_index=-1,
|
||||
)
|
||||
hidden_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
hidden_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[self.dtype],
|
||||
seq_len_dim_index=-1,
|
||||
)
|
||||
|
||||
key_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 1)
|
||||
value_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 2)
|
||||
hidden_cache_key = llm_datadist.CacheKey(self.data_pipe.cluster_id,
|
||||
req_id, 3)
|
||||
|
||||
# Deallocate buffer allocated in last round.
|
||||
if self.key_buffer:
|
||||
try:
|
||||
self.data_pipe.deallocate_buffer(self.key_buffer)
|
||||
self.data_pipe.deallocate_buffer(self.value_buffer)
|
||||
self.data_pipe.deallocate_buffer(self.hidden_buffer)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to free kv cache buffer, Error code: {str(e)}")
|
||||
|
||||
try:
|
||||
self.key_buffer, key = self.data_pipe.recv_tensor(
|
||||
key_desc, key_cache_key)
|
||||
self.value_buffer, value = self.data_pipe.recv_tensor(
|
||||
value_desc, value_cache_key)
|
||||
self.hidden_buffer, hidden = self.data_pipe.recv_tensor(
|
||||
hidden_desc, hidden_cache_key)
|
||||
key = key.view(self.num_layers, num_tokens, self.num_heads,
|
||||
self.head_size)
|
||||
value = value.view(self.num_layers, num_tokens, self.num_heads,
|
||||
self.head_size)
|
||||
hidden = hidden.view(num_tokens, self.hidden_size)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Fail to receive kv cache and hidden states of request: {orig_req_id} "
|
||||
f"Error is {str(e)}")
|
||||
return [None, None, None, None]
|
||||
|
||||
return [key, value, hidden, roi]
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
@@ -1,379 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as vllm_envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.logger import logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
|
||||
class SimpleConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
self.config = config
|
||||
self.model_config = config.model_config.hf_config
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.is_deepseek_mla = config.model_config.is_deepseek_mla
|
||||
self.use_mla_opt = not vllm_envs.VLLM_MLA_DISABLE
|
||||
self.n_layer = self.config.model_config.get_num_layers(
|
||||
self.config.parallel_config)
|
||||
|
||||
self.producer_data_pipe: Optional[SimplePipe]
|
||||
self.consumer_data_pipe: Optional[SimplePipe]
|
||||
|
||||
self.producer_buffer: Optional[SimpleBuffer]
|
||||
self.consumer_buffer: Optional[SimpleBuffer]
|
||||
|
||||
if self.config.kv_transfer_config.is_kv_producer:
|
||||
self.producer_data_pipe = SimplePipe(
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
kv_transfer_config=config.kv_transfer_config,
|
||||
hostname="",
|
||||
port_offset=rank,
|
||||
)
|
||||
self.producer_buffer = SimpleBuffer(self.producer_data_pipe)
|
||||
else:
|
||||
self.consumer_data_pipe = SimplePipe(
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
kv_transfer_config=config.kv_transfer_config,
|
||||
hostname="",
|
||||
port_offset=rank,
|
||||
)
|
||||
self.consumer_buffer = SimpleBuffer(self.consumer_data_pipe)
|
||||
|
||||
def select(
|
||||
self,
|
||||
input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor],
|
||||
req_id: str,
|
||||
) -> List[Optional[torch.Tensor]]:
|
||||
|
||||
assert self.consumer_buffer is not None, (
|
||||
"Please initialize the "
|
||||
"consumer buffer before calling select.")
|
||||
return self.consumer_buffer.drop_select(input_tokens, roi, req_id)
|
||||
|
||||
def insert(
|
||||
self,
|
||||
input_tokens: torch.Tensor,
|
||||
roi: torch.Tensor,
|
||||
keys: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
hidden: torch.Tensor,
|
||||
req_id: str,
|
||||
) -> None:
|
||||
|
||||
assert self.producer_buffer is not None, (
|
||||
"Please initialize the "
|
||||
"producer buffer before calling insert.")
|
||||
self.producer_buffer.insert(input_tokens, roi, keys, values, hidden,
|
||||
req_id)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
model_config = self.model_config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
|
||||
# Deepseek's MLA (Multi-head Latent Attention) uses two different
|
||||
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
|
||||
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
|
||||
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
|
||||
# kv_lora_rank + qk_rope_head_dim].
|
||||
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
|
||||
# to a kv_cache shape of [2, num_blks, blk_size,
|
||||
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
|
||||
# For more details, see vllm/attention/backends/mla/common.py.
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
head_size = (model_config.kv_lora_rank +
|
||||
model_config.qk_rope_head_dim)
|
||||
num_heads = 1
|
||||
elif self.is_deepseek_mla and not self.use_mla_opt:
|
||||
head_size = (model_config.qk_nope_head_dim +
|
||||
model_config.qk_rope_head_dim)
|
||||
else:
|
||||
head_size = getattr(
|
||||
model_config,
|
||||
"head_dim",
|
||||
int(hidden_size // num_attention_heads),
|
||||
)
|
||||
# Enumerate over all requests and insert them one by one.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You have some decode requests while using "
|
||||
"SimpleConnector. Their KVCache won't be sent.")
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
|
||||
keys, values = [], []
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
key_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache.reshape(-1, num_heads, head_size)
|
||||
else:
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
||||
|
||||
# shape: [num_layers, num_tokens, num_heads, head_size]
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
|
||||
# Currently we haven't considered situation of roi, pass None here.
|
||||
self.insert(
|
||||
current_tokens,
|
||||
None,
|
||||
keys,
|
||||
values,
|
||||
hidden_or_intermediate_states[start_pos:end_pos],
|
||||
cur_req_id,
|
||||
)
|
||||
|
||||
logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
) -> Tuple[
|
||||
Union[torch.Tensor, IntermediateTensors],
|
||||
bool,
|
||||
"ModelInputForGPUWithSamplingMetadata",
|
||||
]:
|
||||
bypass_model_exec = True
|
||||
|
||||
model_config = self.model_config
|
||||
|
||||
# get model config
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
num_heads, head_dim = kv_caches[0].shape[-2:]
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
num_layers = end_layer - start_layer
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
head_size = (model_config.kv_lora_rank +
|
||||
model_config.qk_rope_head_dim)
|
||||
num_heads = 1
|
||||
elif self.is_deepseek_mla and not self.use_mla_opt:
|
||||
head_size = (model_config.qk_nope_head_dim +
|
||||
model_config.qk_rope_head_dim)
|
||||
else:
|
||||
head_size = getattr(
|
||||
model_config,
|
||||
"head_dim",
|
||||
int(hidden_size // num_attention_heads),
|
||||
)
|
||||
self.consumer_buffer.num_heads = num_heads # type: ignore
|
||||
self.consumer_buffer.num_layers = num_layers # type: ignore
|
||||
self.consumer_buffer.head_size = head_size # type: ignore
|
||||
self.consumer_buffer.dtype = kv_caches[0].dtype # type: ignore
|
||||
self.consumer_buffer.hidden_size = hidden_size # type: ignore
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
|
||||
total_tokens = model_input.attn_metadata.num_prefill_tokens + model_input.attn_metadata.num_decode_tokens
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
input_tokens_list = []
|
||||
num_computed_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
# enumerate different requests
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
logger.warning("You should set --enable_chunked_prefill=False "
|
||||
"and --max_num_batched_tokens "
|
||||
"should be equal to --max_seq_len_to_capture")
|
||||
bypass_model_exec = False
|
||||
assert start_pos == num_prefill_tokens
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
cur_req_id = list(model_input.request_ids_to_seq_ids.keys())[idx]
|
||||
|
||||
ret = self.select(
|
||||
current_tokens,
|
||||
torch.ones_like(current_tokens, dtype=bool),
|
||||
cur_req_id,
|
||||
)
|
||||
if ret[0] is None:
|
||||
# didn't find any match.
|
||||
bypass_model_exec = False
|
||||
num_computed_tokens_list.append(0)
|
||||
continue
|
||||
|
||||
keys: torch.Tensor = ret[0]
|
||||
values: torch.Tensor = ret[1]
|
||||
hidden: torch.Tensor = ret[2]
|
||||
|
||||
num_computed_tokens = keys.shape[1]
|
||||
num_computed_tokens_list.append(num_computed_tokens)
|
||||
|
||||
# check if both KV cache and the hidden states are received
|
||||
# If not, need to redo the forwarding to compute missing states
|
||||
if not all([(num_computed_tokens == num_tokens), hidden is not None
|
||||
]):
|
||||
bypass_model_exec = False
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# put received KV caches into paged memory
|
||||
for i in range(
|
||||
model_executable.model.start_layer,
|
||||
model_executable.model.end_layer,
|
||||
):
|
||||
|
||||
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
||||
layer = model_executable.model.layers[i]
|
||||
|
||||
if self.is_deepseek_mla and self.use_mla_opt:
|
||||
layer.self_attn.attn = layer.self_attn.mla_attn
|
||||
key_cache = kv_cache
|
||||
slots = slot_mapping[start_pos:end_pos]
|
||||
sliced_key = keys[i - model_executable.model.start_layer]
|
||||
torch_npu._npu_reshape_and_cache_siso(key=sliced_key,
|
||||
key_cache=key_cache,
|
||||
slot_indices=slots)
|
||||
else:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
sliced_key = keys[i - model_executable.model.start_layer]
|
||||
sliced_value = values[i -
|
||||
model_executable.model.start_layer]
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=sliced_key,
|
||||
value=sliced_value,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
slot_indices=slot_mapping[start_pos:end_pos],
|
||||
)
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
if not bypass_model_exec:
|
||||
# Some of the KV cache is not retrieved
|
||||
# Here we will fall back to normal model forwarding
|
||||
# But optionally you can adjust model_input so that you only do
|
||||
# prefilling on those tokens that are missing KV caches.
|
||||
if get_dp_group().world_size > 1:
|
||||
bypass_model_exec = True
|
||||
hidden_or_intermediate_states = torch.empty(
|
||||
[total_tokens, hidden_size],
|
||||
dtype=kv_caches[0].dtype,
|
||||
device=kv_caches[0].device)
|
||||
logger.warning(
|
||||
"[Detect there is more one DP rank in this decode node, in this scenario, no recompute is expected when kv cache dose not received.]"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[rank%d]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.",
|
||||
torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
else:
|
||||
logger.debug(
|
||||
"[rank%d]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.",
|
||||
torch.distributed.get_rank(),
|
||||
)
|
||||
# Can't directly concat here which might cause error when bs = 1.
|
||||
# hidden_or_intermediate_states = torch.empty(total_num_tokens, hidden_size, dtype=kv_caches[0].dtype, device=kv_caches[0].device)
|
||||
if len(hidden_or_intermediate_states_for_one_req) == 1:
|
||||
hidden = hidden_or_intermediate_states_for_one_req[0]
|
||||
tmp_indice = torch.tensor([0] * hidden.shape[0],
|
||||
dtype=torch.int64).npu()
|
||||
hidden_or_intermediate_states = torch.empty_like(hidden)
|
||||
torch_npu.scatter_update_(
|
||||
hidden_or_intermediate_states,
|
||||
tmp_indice,
|
||||
hidden,
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def close(self):
|
||||
self.producer_data_pipe.close() # type: ignore
|
||||
self.consumer_data_pipe.close() # type: ignore
|
||||
self.producer_buffer.close() # type: ignore
|
||||
self.consumer_buffer.close() # type: ignore
|
||||
@@ -1,207 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
import msgpack # type: ignore
|
||||
import torch
|
||||
import torch_npu
|
||||
import torchair # type: ignore
|
||||
import zmq # type: ignore
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import logger
|
||||
from vllm.utils import get_ip
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
from vllm_ascend.distributed.kv_transfer.utils import NPU_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
|
||||
class SimplePipe(KVPipeBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank,
|
||||
local_rank,
|
||||
kv_transfer_config,
|
||||
hostname: str = "",
|
||||
port_offset: int = 0, # NPU offset in current P/D instance.
|
||||
):
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
# Currently for 1P1D situation, we use cluster_id=0 for both Prefill and Decode
|
||||
# Will change here in the future to support xPyD.
|
||||
self.cluster_id = 0
|
||||
self.config = kv_transfer_config
|
||||
kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
|
||||
kv_role = kv_transfer_config.kv_role
|
||||
if kv_role == "kv_producer":
|
||||
self.role = llm_datadist.LLMRole.PROMPT
|
||||
elif kv_role == "kv_consumer":
|
||||
self.role = llm_datadist.LLMRole.DECODER
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"kv_role should be inside [kv_producer, kv_consumer]")
|
||||
|
||||
prefill_device_ips = kv_connector_extra_config.get(
|
||||
"prefill_device_ips", None)
|
||||
decode_device_ips = kv_connector_extra_config.get(
|
||||
"decode_device_ips", None)
|
||||
if prefill_device_ips is None or decode_device_ips is None:
|
||||
raise ValueError(
|
||||
"Please specify prefill_device_ips and decode_device_ips"
|
||||
"in kv_transfer_config.kv_connector_extra_config")
|
||||
p_device_num = len(prefill_device_ips)
|
||||
d_device_num = len(decode_device_ips)
|
||||
# When number of devices in P and D is not equal,
|
||||
# we assume that device in D can be mapped to any device in P.
|
||||
self.p_device_rank = self.rank % p_device_num
|
||||
self.d_device_rank = self.rank % d_device_num
|
||||
|
||||
self.prompt_ip_list = prefill_device_ips
|
||||
self.decode_ip_list = decode_device_ips
|
||||
self.llmdatadist_comm_port = kv_connector_extra_config.get(
|
||||
"llmdatadist_comm_port", 26000)
|
||||
# LLMDataDist initializing.
|
||||
self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)
|
||||
self._prepare_data_dist()
|
||||
# Decoder needs to initialize and link cluster
|
||||
if self.role == llm_datadist.LLMRole.DECODER:
|
||||
self.cluster = self._make_cluster()
|
||||
_, ret = self.data_dist.link_clusters([self.cluster], 20000)
|
||||
logger.info(
|
||||
f"rank {self.rank}, local_rank {self.local_rank} link, ret={ret}"
|
||||
)
|
||||
|
||||
# If `proxy_ip` or `proxy_port` is `""`,
|
||||
# then the ping thread will not be enabled.
|
||||
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
|
||||
proxy_port = self.config.get_from_extra_config("proxy_port", "")
|
||||
if proxy_ip == "" or proxy_port == "":
|
||||
self.proxy_address = ""
|
||||
else:
|
||||
self.proxy_address = proxy_ip + ":" + str(proxy_port)
|
||||
|
||||
self._register_thread = None
|
||||
if port_offset == 0 and self.proxy_address != "":
|
||||
# Initialize zmq socket and register to proxy.
|
||||
# Note that only NPU 0 of each P/D instance register to proxy.
|
||||
if not hostname:
|
||||
hostname = get_ip() # Get ip of current host.
|
||||
port = int(kv_transfer_config.kv_port) + port_offset
|
||||
if port == 0:
|
||||
raise ValueError("Port cannot be 0")
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
# Each card corresponds to a ZMQ address.
|
||||
self.zmq_address = f"{self._hostname}:{self._port}"
|
||||
|
||||
self.context = zmq.Context() # type: ignore
|
||||
self.router_socket = self.context.socket(
|
||||
zmq.ROUTER) # type: ignore
|
||||
self.router_socket.bind(f"tcp://{self.zmq_address}")
|
||||
# The `http_port` must be consistent with the serving port of OpenAI.
|
||||
self.http_address = (
|
||||
f"{self._hostname}:"
|
||||
f"{self.config.kv_connector_extra_config['http_port']}")
|
||||
self._register_thread = threading.Thread(
|
||||
target=self._register_to_proxy, daemon=True)
|
||||
self._register_thread.start()
|
||||
|
||||
def _prepare_data_dist(self):
|
||||
options = {
|
||||
"llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
|
||||
}
|
||||
if self.role == llm_datadist.LLMRole.PROMPT:
|
||||
options["ge.exec.deviceId"] = str(self.local_rank)
|
||||
options["llm.listenIpInfo"] = (
|
||||
f"{self.prompt_ip_list[self.p_device_rank]}:{self.llmdatadist_comm_port}"
|
||||
)
|
||||
else:
|
||||
options["ge.exec.deviceId"] = str(self.local_rank)
|
||||
print(f"prepare datadist, options: {options}")
|
||||
self.data_dist.init(options)
|
||||
self.kv_transfer = self.data_dist.kv_cache_manager
|
||||
print(f"{self.rank} rank data dist is ready")
|
||||
|
||||
def _make_cluster(self):
|
||||
cluster = llm_datadist.LLMClusterInfo()
|
||||
cluster.remote_cluster_id = self.cluster_id
|
||||
local_ip = self.decode_ip_list[self.d_device_rank]
|
||||
remote_ip = self.prompt_ip_list[self.p_device_rank]
|
||||
cluster.append_local_ip_info(local_ip, 0)
|
||||
cluster.append_remote_ip_info(remote_ip, self.llmdatadist_comm_port)
|
||||
return cluster
|
||||
|
||||
def _register_to_proxy(self):
|
||||
sock = self.context.socket(zmq.DEALER) # type: ignore
|
||||
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) # type: ignore
|
||||
logger.debug("ping start, zmq_address:%s", self.zmq_address)
|
||||
sock.connect(f"tcp://{self.proxy_address}")
|
||||
data = {
|
||||
"type": "P" if self.config.is_kv_producer else "D",
|
||||
"http_address": self.http_address,
|
||||
"zmq_address": self.zmq_address,
|
||||
}
|
||||
while True:
|
||||
sock.send(msgpack.dumps(data))
|
||||
time.sleep(3)
|
||||
|
||||
def send_tensor(
|
||||
self,
|
||||
tensor: Optional[torch.Tensor],
|
||||
tensor_desc: llm_datadist.CacheDesc,
|
||||
tensor_key: llm_datadist.CacheKey,
|
||||
) -> llm_datadist.Cache:
|
||||
buffer = self.kv_transfer.allocate_cache(tensor_desc, [tensor_key])
|
||||
buffer_addr = buffer.per_device_tensor_addrs[0]
|
||||
data_tensor = torchair.llm_datadist.create_npu_tensors(
|
||||
tensor_desc.shape, tensor.dtype, buffer_addr)[0] # type: ignore
|
||||
update_indices = torch.tensor(
|
||||
[0] * tensor.shape[0], # type: ignore
|
||||
dtype=torch.int64).npu()
|
||||
torch_npu.scatter_update_(data_tensor, update_indices, tensor, axis=-1)
|
||||
# Free cache_id of buffer, actual deallocate will happen after consumer performing pull_cache.
|
||||
self.kv_transfer.deallocate_cache(buffer)
|
||||
return buffer
|
||||
|
||||
def recv_tensor(
|
||||
self,
|
||||
tensor_desc: llm_datadist.CacheDesc,
|
||||
tensor_key: llm_datadist.CacheKey,
|
||||
) -> llm_datadist.Cache:
|
||||
"""Note that this function only creates empty tensor on buffer addr and returns it."""
|
||||
tmp_buffer = self.kv_transfer.allocate_cache(tensor_desc)
|
||||
buffer_addr = tmp_buffer.per_device_tensor_addrs[0]
|
||||
data_tensor = torchair.llm_datadist.create_npu_tensors(
|
||||
tensor_desc.shape,
|
||||
NPU_DTYPE_TO_TORCH_DTYPE[tensor_desc.data_type],
|
||||
buffer_addr,
|
||||
)[0]
|
||||
self.kv_transfer.pull_cache(tensor_key, tmp_buffer, 0)
|
||||
# tmp_buffer is allocated without key and will be deallocated here immediately.
|
||||
# Free buffer here will cause accuracy problem.
|
||||
# self.kv_transfer.deallocate_cache(tmp_buffer)
|
||||
return tmp_buffer, data_tensor
|
||||
|
||||
def deallocate_buffer(self, buffer: llm_datadist.Cache):
|
||||
self.kv_transfer.deallocate_cache(buffer)
|
||||
|
||||
def close(self):
|
||||
self.data_dist.unlink_clusters([self.cluster], 5000)
|
||||
@@ -1,40 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import llm_datadist # type: ignore
|
||||
import torch
|
||||
|
||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.float16: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.bfloat16: llm_datadist.DataType.DT_BF16,
|
||||
torch.float: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.float32: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.int8: llm_datadist.DataType.DT_INT8,
|
||||
torch.int64: llm_datadist.DataType.DT_INT64,
|
||||
torch.int32: llm_datadist.DataType.DT_INT32,
|
||||
}
|
||||
|
||||
NPU_DTYPE_TO_TORCH_DTYPE = {
|
||||
llm_datadist.DataType.DT_FLOAT16: torch.half,
|
||||
llm_datadist.DataType.DT_FLOAT16: torch.float16,
|
||||
llm_datadist.DataType.DT_BF16: torch.bfloat16,
|
||||
llm_datadist.DataType.DT_FLOAT: torch.float,
|
||||
llm_datadist.DataType.DT_FLOAT: torch.float32,
|
||||
llm_datadist.DataType.DT_INT8: torch.int8,
|
||||
llm_datadist.DataType.DT_INT64: torch.int64,
|
||||
llm_datadist.DataType.DT_INT32: torch.int32,
|
||||
}
|
||||
@@ -1,470 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import torchair # type: ignore
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.logger import logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
|
||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.float16: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.bfloat16: llm_datadist.DataType.DT_BF16,
|
||||
torch.float: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.float32: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.int8: llm_datadist.DataType.DT_INT8,
|
||||
torch.int64: llm_datadist.DataType.DT_INT64,
|
||||
torch.int32: llm_datadist.DataType.DT_INT32
|
||||
}
|
||||
|
||||
# Get all device ips using hccn_tool
|
||||
HCCN_TOOL_PATH = envs.HCCN_PATH
|
||||
|
||||
|
||||
def get_device_ips():
|
||||
world_size = 8
|
||||
npu_info = subprocess.run(['npu-smi', 'info', '-m'],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
|
||||
raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")
|
||||
re_result = re.match(r'.*\n\t([0-9]+).*', npu_info.stdout)
|
||||
if re_result is None:
|
||||
raise RuntimeError("Can't find npu start index")
|
||||
npu_start_idx = int(re_result.group(1))
|
||||
device_ip_list = []
|
||||
for ip_offset in range(world_size):
|
||||
cmd = [
|
||||
HCCN_TOOL_PATH, '-i', f'{npu_start_idx + ip_offset}', '-ip', '-g'
|
||||
]
|
||||
device_ip_info = subprocess.run(cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
re_result = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout)
|
||||
if re_result is None:
|
||||
raise RuntimeError("Can't find npu ip")
|
||||
device_ip = re_result.group(1)
|
||||
device_ip_list.append(device_ip)
|
||||
return device_ip_list
|
||||
|
||||
|
||||
class KVTransferEngine:
|
||||
|
||||
def __init__(self, world_size, n_layer, role, local_rank):
|
||||
self.world_size = world_size
|
||||
self.n_layer = n_layer
|
||||
self.role = role
|
||||
self.device_ip_list = get_device_ips()
|
||||
self.local_rank = local_rank
|
||||
self.cluster_id = local_rank
|
||||
self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)
|
||||
|
||||
prompt_device_ids = envs.PROMPT_DEVICE_ID
|
||||
decode_device_ids = envs.DECODE_DEVICE_ID
|
||||
if prompt_device_ids is None or decode_device_ids is None:
|
||||
raise ValueError(
|
||||
"Please specify env PROMPT_DEVICE_ID or DECODE_DEVICE_ID")
|
||||
|
||||
prompt_ids = [
|
||||
int(x.strip()) for x in prompt_device_ids.split(",") if x.strip()
|
||||
]
|
||||
decode_ids = [
|
||||
int(x.strip()) for x in decode_device_ids.split(",") if x.strip()
|
||||
]
|
||||
|
||||
self.prompt_ip_list = [self.device_ip_list[i] for i in prompt_ids]
|
||||
self.decode_ip_list = [self.device_ip_list[i] for i in decode_ids]
|
||||
|
||||
def prepare_data_dist(self):
|
||||
options = {
|
||||
"llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
|
||||
}
|
||||
if self.role == llm_datadist.LLMRole.PROMPT:
|
||||
options["ge.exec.deviceId"] = str(self.local_rank)
|
||||
options[
|
||||
"llm.listenIpInfo"] = f"{self.prompt_ip_list[self.local_rank]}:{envs.LLMDATADIST_COMM_PORT}"
|
||||
else:
|
||||
options["ge.exec.deviceId"] = str(self.local_rank)
|
||||
self.data_dist.init(options)
|
||||
self.kv_transfer = self.data_dist.kv_cache_manager
|
||||
logger.info(
|
||||
f"{self.local_rank}/{self.world_size} rank data dist is ready")
|
||||
|
||||
def make_cluster(self, prefill_ip, cluster_id=-1):
|
||||
cluster = llm_datadist.LLMClusterInfo()
|
||||
cluster.remote_cluster_id = cluster_id
|
||||
local_ip = self.decode_ip_list[self.local_rank]
|
||||
remote_ip = prefill_ip
|
||||
cluster.append_local_ip_info(local_ip, 0)
|
||||
cluster.append_remote_ip_info(remote_ip, 26000)
|
||||
return cluster
|
||||
|
||||
|
||||
class LLMDataDistConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
self.config = config
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
|
||||
if self.config.kv_transfer_config.kv_role == "kv_producer":
|
||||
self.role = llm_datadist.LLMRole.PROMPT
|
||||
elif self.config.kv_transfer_config.kv_role == "kv_consumer":
|
||||
self.role = llm_datadist.LLMRole.DECODER
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"kv_role should be inside [kv_producer, kv_consumer]")
|
||||
|
||||
self.world_size = self.config.parallel_config.world_size
|
||||
self.n_layer = self.config.model_config.get_num_layers(
|
||||
self.config.parallel_config)
|
||||
|
||||
self.llm_datadist_engine = KVTransferEngine(self.world_size,
|
||||
self.n_layer, self.role,
|
||||
self.local_rank)
|
||||
if self.role == llm_datadist.LLMRole.PROMPT:
|
||||
self.llm_datadist_engine.prepare_data_dist()
|
||||
else:
|
||||
self.llm_datadist_engine.prepare_data_dist()
|
||||
self.cluster = self.llm_datadist_engine.make_cluster(
|
||||
self.llm_datadist_engine.prompt_ip_list[self.local_rank],
|
||||
self.llm_datadist_engine.cluster_id)
|
||||
_, ret = self.llm_datadist_engine.data_dist.link_clusters(
|
||||
[self.cluster], 20000)
|
||||
logger.info(f"local_rank {self.local_rank} link, ret={ret}")
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors]
|
||||
) -> None:
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
model_config = model_executable.model.config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
|
||||
num_layer = end_layer - start_layer
|
||||
|
||||
# Get shape of input_tokens_tensor and kv_cache
|
||||
input_shape = (1, input_tokens_tensor.shape[0], 1, 1)
|
||||
hidden_shape = (1, input_tokens_tensor.shape[0], 1, hidden_size)
|
||||
kv_shape = (1, input_tokens_tensor.shape[0], num_heads, head_size)
|
||||
|
||||
assert kv_caches[0].dtype == hidden_or_intermediate_states.dtype
|
||||
kv_hidden_dtype = kv_caches[0].dtype
|
||||
input_dtype = torch.int32
|
||||
|
||||
# initialize LLMDatadist data structure
|
||||
key_desc = llm_datadist.CacheDesc(
|
||||
num_layer,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=1)
|
||||
value_desc = llm_datadist.CacheDesc(
|
||||
num_layer,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=1)
|
||||
input_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
input_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
hidden_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
hidden_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
|
||||
key_cache_keys = [
|
||||
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 1)
|
||||
]
|
||||
value_cache_keys = [
|
||||
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 2)
|
||||
]
|
||||
input_cache_keys = [
|
||||
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 3)
|
||||
]
|
||||
hidden_cache_keys = [
|
||||
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 4)
|
||||
]
|
||||
|
||||
self.key_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
key_desc, key_cache_keys)
|
||||
self.value_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
value_desc, value_cache_keys)
|
||||
self.input_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
input_desc, input_cache_keys)
|
||||
self.hidden_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
hidden_desc, hidden_cache_keys)
|
||||
|
||||
key_buffer_addr = self.key_buffer.per_device_tensor_addrs[0]
|
||||
value_buffer_addr = self.value_buffer.per_device_tensor_addrs[0]
|
||||
input_buffer_addr = self.input_buffer.per_device_tensor_addrs[0]
|
||||
hidden_buffer_addr = self.hidden_buffer.per_device_tensor_addrs[0]
|
||||
|
||||
self.key_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
key_desc.shape, kv_hidden_dtype, key_buffer_addr)
|
||||
self.value_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
value_desc.shape, kv_hidden_dtype, value_buffer_addr)
|
||||
self.input_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
input_desc.shape, input_dtype, input_buffer_addr)
|
||||
self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addr)
|
||||
|
||||
indices = torch.tensor([0], dtype=torch.int64).npu()
|
||||
|
||||
# copy cache data into llm datadist cache using scatter update
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos].to(
|
||||
torch.int32)
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
|
||||
key_cache = kv_cache[0].view(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].view(-1, num_heads, head_size)
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
# copy key into datadist
|
||||
k = self.key_cache[layer_id][:, start_pos:end_pos, :, :]
|
||||
new_k = key_cache[current_slot_mapping].unsqueeze(0)
|
||||
torch_npu.scatter_update_(k, indices, new_k, axis=-2)
|
||||
|
||||
# copy value into datadist
|
||||
val = self.value_cache[layer_id][:, start_pos:end_pos, :, :]
|
||||
new_val = value_cache[current_slot_mapping].unsqueeze(0)
|
||||
torch_npu.scatter_update_(val, indices, new_val, axis=-2)
|
||||
|
||||
# copy input into datadist
|
||||
inp = self.input_cache[0][:, start_pos:end_pos, :, :]
|
||||
new_inp = current_tokens.view(1, current_tokens.shape[0], 1, 1)
|
||||
torch_npu.scatter_update_(inp, indices, new_inp, axis=-2)
|
||||
|
||||
# copy hidden into datadist
|
||||
hid = self.hidden_cache[0][:, start_pos:end_pos, :, :]
|
||||
hid_shape0, hid_shape1 = hidden_or_intermediate_states[
|
||||
start_pos:end_pos].shape
|
||||
new_hid = hidden_or_intermediate_states[start_pos:end_pos].view(
|
||||
1, hid_shape0, 1, hid_shape1)
|
||||
torch_npu.scatter_update_(hid, indices, new_hid, axis=-2)
|
||||
|
||||
logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor]
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
bypass_model_exec = True
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
input_tokens_list = []
|
||||
num_computed_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
# get model config
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
model_config = model_executable.model.config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
num_layer = end_layer - start_layer
|
||||
|
||||
# get input_tensor_shape and hidden_shape
|
||||
input_shape = (1, input_tokens_tensor.shape[0], 1, 1)
|
||||
hidden_shape = (1, input_tokens_tensor.shape[0], 1, hidden_size)
|
||||
kv_shape = (1, input_tokens_tensor.shape[0], num_heads, head_size)
|
||||
|
||||
kv_hidden_dtype = kv_caches[0].dtype
|
||||
input_dtype = torch.int32
|
||||
|
||||
# Add LLM DataDist initialization
|
||||
key_desc = llm_datadist.CacheDesc(
|
||||
num_layer,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
value_desc = llm_datadist.CacheDesc(
|
||||
num_layer,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
input_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
input_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
hidden_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
hidden_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
self.decode_key_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
key_desc)
|
||||
self.decode_value_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
value_desc)
|
||||
self.decode_input_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
input_desc)
|
||||
self.decode_hidden_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
hidden_desc)
|
||||
key_buffer_addrs = self.decode_key_buffer.per_device_tensor_addrs[0]
|
||||
value_buffer_addrs = self.decode_value_buffer.per_device_tensor_addrs[
|
||||
0]
|
||||
input_buffer_addrs = self.decode_input_buffer.per_device_tensor_addrs[
|
||||
0]
|
||||
hidden_buffer_addrs = self.decode_hidden_buffer.per_device_tensor_addrs[
|
||||
0]
|
||||
self.key_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
key_desc.shape, kv_hidden_dtype, key_buffer_addrs)
|
||||
self.value_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
value_desc.shape, kv_hidden_dtype, value_buffer_addrs)
|
||||
self.input_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
input_desc.shape, input_dtype, input_buffer_addrs)
|
||||
self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addrs)
|
||||
|
||||
key_cache_key = llm_datadist.CacheKeyByIdAndIndex(
|
||||
self.cluster.remote_cluster_id, 1, 0)
|
||||
value_cache_key = llm_datadist.CacheKeyByIdAndIndex(
|
||||
self.cluster.remote_cluster_id, 2, 0)
|
||||
input_cache_key = llm_datadist.CacheKeyByIdAndIndex(
|
||||
self.cluster.remote_cluster_id, 3, 0)
|
||||
hidden_cache_key = llm_datadist.CacheKeyByIdAndIndex(
|
||||
self.cluster.remote_cluster_id, 4, 0)
|
||||
|
||||
self.llm_datadist_engine.kv_transfer.pull_cache(
|
||||
key_cache_key, self.decode_key_buffer, 0)
|
||||
self.llm_datadist_engine.kv_transfer.pull_cache(
|
||||
value_cache_key, self.decode_value_buffer, 0)
|
||||
self.llm_datadist_engine.kv_transfer.pull_cache(
|
||||
input_cache_key, self.decode_input_buffer, 0)
|
||||
self.llm_datadist_engine.kv_transfer.pull_cache(
|
||||
hidden_cache_key, self.decode_hidden_buffer, 0)
|
||||
|
||||
keys = self.key_cache
|
||||
values = self.value_cache
|
||||
inputs = self.input_cache
|
||||
hidden = self.hidden_cache
|
||||
|
||||
# enumerate different requests
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
num_computed_tokens = inputs[0][0, start_pos:end_pos, 0,
|
||||
0].shape[0]
|
||||
num_computed_tokens_list.append(num_computed_tokens)
|
||||
|
||||
# check if both KV cache and the hidden states are received
|
||||
# If not, need to redo the forwarding to compute missing states
|
||||
if not all([(num_computed_tokens == num_tokens), hidden is not None
|
||||
]):
|
||||
bypass_model_exec = False
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# put received KV caches into paged memory
|
||||
for i in range(model_executable.model.start_layer,
|
||||
model_executable.model.end_layer):
|
||||
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
|
||||
sliced_key = keys[i - model_executable.model.start_layer][
|
||||
0, start_pos:end_pos, :, :]
|
||||
sliced_value = values[i - model_executable.model.start_layer][
|
||||
0, start_pos:end_pos, :, :]
|
||||
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=sliced_key,
|
||||
value=sliced_value,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
slot_indices=slot_mapping[start_pos:end_pos])
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(
|
||||
hidden[0][0, start_pos:end_pos, 0, :])
|
||||
|
||||
if not bypass_model_exec:
|
||||
# Some of the KV cache is not retrieved
|
||||
# Here we will fall back to normal model forwarding
|
||||
# But optionally you can adjust model_input so that you only do
|
||||
# prefilling on those tokens that are missing KV caches.
|
||||
logger.info(
|
||||
"[rank%d][D]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
else:
|
||||
logger.info(
|
||||
"[rank%d][D]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def close(self, ):
|
||||
self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster],
|
||||
5000)
|
||||
Reference in New Issue
Block a user