feat: support compile torchair graph while warming up (#839)

### What this PR does / why we need it?
feat: support compile torchair graph while warming up

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-05-31 06:03:03 +08:00
committed by GitHub
parent d9fb027068
commit 507ae627ca
7 changed files with 242 additions and 234 deletions

View File

@@ -31,6 +31,7 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is
EOS_TOKEN_ID = 50256
@@ -83,11 +84,10 @@ def create_scheduler(
cache_dtype="auto",
**kwargs_cache,
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
)
vllm_config = VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config)
kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests
tensors={},
@@ -98,10 +98,7 @@ def create_scheduler(
)
cache_config.num_gpu_blocks = 10000
return AscendScheduler(
scheduler_config,
model_config,
cache_config,
lora_config=None,
vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
@@ -126,17 +123,27 @@ def create_requests(num_requests: int,
else:
mm_position = None
mm_inputs = None
request = Request(
request_id=f"{i}",
prompt=None,
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
if vllm_version_is("0.9.0"):
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
arrival_time=0,
eos_token_id=EOS_TOKEN_ID,
)
else:
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
)
requests.append(request)
return requests
@@ -225,12 +232,9 @@ def test_stop_via_update_from_output():
requests[0].request_id: 1,
requests[1].request_id: 2
},
scheduled_spec_decode_tokens={},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
@@ -275,12 +279,9 @@ def test_stop_via_update_from_output():
requests[0].request_id: 3,
requests[1].request_id: 2
},
scheduled_spec_decode_tokens={},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
@@ -323,12 +324,9 @@ def test_stop_via_update_from_output():
requests[0].request_id: 3,
requests[1].request_id: 1
},
scheduled_spec_decode_tokens={},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
@@ -369,11 +367,9 @@ def test_stop_via_update_from_output():
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
scheduled_spec_decode_tokens={},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],