vllm-ascend/examples/offline_inference_npu_long_seq.py
wangxiyuan f10acddb78 drop ascend scheduler (#4498)
The Ascend scheduler was added earlier for the non-chunked-prefill case,
because the NPU ops didn't work well with chunked prefill.

Now that the ops work well with chunked prefill, it's time to remove the
Ascend scheduler and use vLLM's default scheduler.

- vLLM version: v0.11.2

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-11-29 16:18:34 +08:00
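
With the Ascend scheduler gone, chunked prefill is driven entirely by the
standard engine arguments. A minimal sketch of enabling it under vLLM's
default scheduler (the model name and token budget here are illustrative,
not taken from this change):

from vllm import LLM

# vLLM's default scheduler handles chunked prefill directly; no
# Ascend-specific scheduler config is needed any more.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # illustrative model
    enable_chunked_prefill=True,
    max_num_batched_tokens=2048,  # per-step token budget used to chunk long prefills
)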

59 lines
1.9 KiB
Python

import argparse
import os
import time

# Set environment overrides before vLLM is imported so they are in effect
# when the engine initializes.
os.environ["VLLM_USE_MODELSCOPE"] = "True"  # fetch weights from ModelScope
os.environ["VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL"] = "1"  # enable context parallelism on Ascend NPUs
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # avoid forking after the NPU runtime is initialized

from vllm import LLM, SamplingParams
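
# Example invocation (illustrative values; assumes enough NPUs are visible
# for the chosen --tp/--pcp/--dcp parallel sizes):
#   python offline_inference_npu_long_seq.py --tp 2 --pcp 2 --iter_times 3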

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_len', type=int, default=1024)  # currently unused; the prompts below are fixed
    parser.add_argument('--output_len', type=int, default=128)
    parser.add_argument('--bs', type=int, default=1)
    parser.add_argument('--model_path', type=str, default="deepseek-ai/DeepSeek-V2-Lite")
    parser.add_argument('--tp', type=int, default=2)   # tensor parallel size
    parser.add_argument('--pcp', type=int, default=2)  # prefill context parallel size
    parser.add_argument('--dcp', type=int, default=1)  # decode context parallel size
    parser.add_argument('--iter_times', type=int, default=1)
    args = parser.parse_args()

    prompts = [
        "The capital of France is",
        "Hello, my name is Tom, I am",
        "The president of the United States is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=args.output_len)
    llm = LLM(
        model=args.model_path,
        trust_remote_code=True,
        enforce_eager=True,
        tensor_parallel_size=args.tp,
        prefill_context_parallel_size=args.pcp,
        decode_context_parallel_size=args.dcp,
        enable_prefix_caching=False,
        enable_expert_parallel=True,
        enable_chunked_prefill=False,
        max_num_batched_tokens=2048,
        max_model_len=1024,
        max_num_seqs=1,
        block_size=128,
        gpu_memory_utilization=0.9,
    )
    # Time the full generate() calls and report the average end-to-end
    # latency per request (this includes all decoded tokens, not just the
    # first one, so it is not a true TTFT measurement).
    t0 = time.time()
    for _ in range(args.iter_times):
        outputs = llm.generate(prompts, sampling_params)
    t1 = time.time()
    print(f"Avg end-to-end latency per request: {(t1 - t0) * 1000 / (args.iter_times * args.bs)} ms")

    for i, output in enumerate(outputs):
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"req_num: {i}\nPrompt: {prompt!r}\nGenerated text: {generated_text!r}")