### What this PR does / why we need it? This PR ports all the deepseek graph mode code and mtp code from v0.7.3 to the main branch --------- Signed-off-by: SidaoY <1024863041@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com> Signed-off-by: mengwei805 <mengwei25@huawei.com> Signed-off-by: libaokui <libaokui@huawei.com> Signed-off-by: q00832892 <qiaoyang19@huawei.com> Signed-off-by: ganyi <pleaplusone.gy@gmail.com> Co-authored-by: SidaoY <1024863041@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: libaokui <libaokui@huawei.com>
129 lines
4.0 KiB
Python
129 lines
4.0 KiB
Python
"""
|
|
This file demonstrates the example usage of disaggregated prefilling
|
|
We will launch 2 vllm instances (NPU 0,1 for prefill and NPU 2,3 for decode),
|
|
and then transfer the KV cache between them.
|
|
"""
|
|
import multiprocessing as mp
|
|
import os
|
|
import time
|
|
from multiprocessing import Event, Process
|
|
|
|
|
|
def clean_up():
|
|
import gc
|
|
|
|
import torch
|
|
from vllm.distributed.parallel_state import (
|
|
destroy_distributed_environment, destroy_model_parallel)
|
|
destroy_model_parallel()
|
|
destroy_distributed_environment()
|
|
gc.collect()
|
|
torch.npu.empty_cache()
|
|
|
|
|
|
def run_prefill(prefill_done, process_close):
|
|
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
|
|
|
from vllm import LLM, SamplingParams
|
|
from vllm.config import KVTransferConfig
|
|
|
|
prompts = [
|
|
"Hello, how are you today?", "Hi, what is your name?",
|
|
"Tell me a very long story.", "what is your favourite book?"
|
|
]
|
|
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
|
|
|
ktc = KVTransferConfig.from_cli(
|
|
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
|
|
)
|
|
|
|
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
|
|
# memory. You may need to adjust the value to fit your GPU.
|
|
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
|
kv_transfer_config=ktc,
|
|
max_model_len=2000,
|
|
gpu_memory_utilization=0.8,
|
|
tensor_parallel_size=2)
|
|
|
|
llm.generate(prompts, sampling_params)
|
|
print("Prefill node is finished.")
|
|
prefill_done.set()
|
|
|
|
# To keep the prefill node running in case the decode node is not done;
|
|
# otherwise, the script might exit prematurely, causing incomplete decoding.
|
|
try:
|
|
while not process_close.is_set():
|
|
time.sleep(1)
|
|
except KeyboardInterrupt:
|
|
print("Script stopped by user.")
|
|
finally:
|
|
print("Cleanup prefill resources")
|
|
del llm
|
|
clean_up()
|
|
|
|
|
|
def run_decode(prefill_done):
|
|
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"
|
|
|
|
from vllm import LLM, SamplingParams
|
|
from vllm.config import KVTransferConfig
|
|
|
|
prompts = [
|
|
"Hello, how are you today?", "Hi, what is your name?",
|
|
"Tell me a very long story.", "what is your favourite book?"
|
|
]
|
|
sampling_params = SamplingParams(temperature=0, top_p=0.95)
|
|
|
|
ktc = KVTransferConfig.from_cli(
|
|
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
|
|
)
|
|
|
|
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
|
kv_transfer_config=ktc,
|
|
max_model_len=2000,
|
|
gpu_memory_utilization=0.8,
|
|
tensor_parallel_size=2)
|
|
|
|
# Wait for the producer to start the comsumer
|
|
print("Waiting for prefill node to finish...")
|
|
prefill_done.wait()
|
|
|
|
# At this point when the prefill_done is set, the kv-cache should have been
|
|
# transferred to this decode node, so we can start decoding.
|
|
outputs = llm.generate(prompts, sampling_params)
|
|
for output in outputs:
|
|
prompt = output.prompt
|
|
generated_text = output.outputs[0].text
|
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
|
|
|
del llm
|
|
clean_up()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
mp.get_context('spawn')
|
|
|
|
prefill_done = Event()
|
|
process_close = Event()
|
|
prefill_process = Process(target=run_prefill,
|
|
args=(
|
|
prefill_done,
|
|
process_close,
|
|
))
|
|
decode_process = Process(target=run_decode, args=(prefill_done, ))
|
|
|
|
# Start prefill node
|
|
prefill_process.start()
|
|
|
|
# Start decode node
|
|
decode_process.start()
|
|
|
|
# Terminate the prefill node when decode is finished
|
|
decode_process.join()
|
|
|
|
# Terminate prefill process
|
|
process_close.set()
|
|
prefill_process.join()
|
|
prefill_process.terminate()
|
|
print("All process done!")
|