feat: support compile torchair graph while warming up (#839)
### What this PR does / why we need it? feat: support compile torchair graph while warming up Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
6
.github/workflows/vllm_ascend_test.yaml
vendored
6
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -108,8 +108,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
|
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
|
||||||
VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
|
VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
|
||||||
# AscendScheduler doesn't work, fix it later
|
pytest -sv tests/singlecard/test_scheduler.py
|
||||||
# pytest -sv tests/singlecard/tets_schedule.py
|
|
||||||
# guided decoding doesn't work, fix it later
|
# guided decoding doesn't work, fix it later
|
||||||
# pytest -sv tests/singlecard/test_guided_decoding.py.py
|
# pytest -sv tests/singlecard/test_guided_decoding.py.py
|
||||||
pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
|
pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
|
||||||
@@ -124,8 +123,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
|
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
|
||||||
VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
|
VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
|
||||||
# AscendScheduler doesn't work, fix it later
|
pytest -sv tests/singlecard/test_scheduler.py
|
||||||
# pytest -sv tests/singlecard/tets_schedule.py
|
|
||||||
# guided decoding doesn't work, fix it later
|
# guided decoding doesn't work, fix it later
|
||||||
# pytest -sv tests/singlecard/test_guided_decoding.py.py
|
# pytest -sv tests/singlecard/test_guided_decoding.py.py
|
||||||
pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
|
pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from vllm.v1.request import Request, RequestStatus
|
|||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
from vllm_ascend.core.scheduler import AscendScheduler
|
from vllm_ascend.core.scheduler import AscendScheduler
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
EOS_TOKEN_ID = 50256
|
EOS_TOKEN_ID = 50256
|
||||||
|
|
||||||
@@ -83,11 +84,10 @@ def create_scheduler(
|
|||||||
cache_dtype="auto",
|
cache_dtype="auto",
|
||||||
**kwargs_cache,
|
**kwargs_cache,
|
||||||
)
|
)
|
||||||
vllm_config = VllmConfig(
|
vllm_config = VllmConfig(scheduler_config=scheduler_config,
|
||||||
scheduler_config=scheduler_config,
|
model_config=model_config,
|
||||||
model_config=model_config,
|
cache_config=cache_config)
|
||||||
cache_config=cache_config,
|
|
||||||
)
|
|
||||||
kv_cache_config = KVCacheConfig(
|
kv_cache_config = KVCacheConfig(
|
||||||
num_blocks=10000, # A large number of blocks to hold all requests
|
num_blocks=10000, # A large number of blocks to hold all requests
|
||||||
tensors={},
|
tensors={},
|
||||||
@@ -98,10 +98,7 @@ def create_scheduler(
|
|||||||
)
|
)
|
||||||
cache_config.num_gpu_blocks = 10000
|
cache_config.num_gpu_blocks = 10000
|
||||||
return AscendScheduler(
|
return AscendScheduler(
|
||||||
scheduler_config,
|
vllm_config,
|
||||||
model_config,
|
|
||||||
cache_config,
|
|
||||||
lora_config=None,
|
|
||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
log_stats=True,
|
log_stats=True,
|
||||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||||
@@ -126,17 +123,27 @@ def create_requests(num_requests: int,
|
|||||||
else:
|
else:
|
||||||
mm_position = None
|
mm_position = None
|
||||||
mm_inputs = None
|
mm_inputs = None
|
||||||
request = Request(
|
if vllm_version_is("0.9.0"):
|
||||||
request_id=f"{i}",
|
request = Request(
|
||||||
prompt=None,
|
request_id=f"{i}",
|
||||||
prompt_token_ids=[i] * num_tokens,
|
prompt_token_ids=[i] * num_tokens,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
multi_modal_inputs=mm_inputs,
|
multi_modal_inputs=mm_inputs,
|
||||||
multi_modal_placeholders=mm_position,
|
multi_modal_placeholders=mm_position,
|
||||||
multi_modal_hashes=None,
|
multi_modal_hashes=None,
|
||||||
eos_token_id=EOS_TOKEN_ID,
|
arrival_time=0,
|
||||||
arrival_time=0,
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
request = Request(
|
||||||
|
request_id=f"{i}",
|
||||||
|
prompt_token_ids=[i] * num_tokens,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
multi_modal_inputs=mm_inputs,
|
||||||
|
multi_modal_placeholders=mm_position,
|
||||||
|
multi_modal_hashes=None,
|
||||||
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
|
)
|
||||||
requests.append(request)
|
requests.append(request)
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
@@ -225,12 +232,9 @@ def test_stop_via_update_from_output():
|
|||||||
requests[0].request_id: 1,
|
requests[0].request_id: 1,
|
||||||
requests[1].request_id: 2
|
requests[1].request_id: 2
|
||||||
},
|
},
|
||||||
|
scheduled_spec_decode_tokens={},
|
||||||
total_num_scheduled_tokens=3,
|
total_num_scheduled_tokens=3,
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
scheduled_spec_decode_tokens={
|
|
||||||
requests[0].request_id: [],
|
|
||||||
requests[1].request_id: [10]
|
|
||||||
},
|
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_input_ids=[],
|
||||||
@@ -275,12 +279,9 @@ def test_stop_via_update_from_output():
|
|||||||
requests[0].request_id: 3,
|
requests[0].request_id: 3,
|
||||||
requests[1].request_id: 2
|
requests[1].request_id: 2
|
||||||
},
|
},
|
||||||
|
scheduled_spec_decode_tokens={},
|
||||||
total_num_scheduled_tokens=5,
|
total_num_scheduled_tokens=5,
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
scheduled_spec_decode_tokens={
|
|
||||||
requests[0].request_id: [10, 42],
|
|
||||||
requests[1].request_id: [13]
|
|
||||||
},
|
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_input_ids=[],
|
||||||
@@ -323,12 +324,9 @@ def test_stop_via_update_from_output():
|
|||||||
requests[0].request_id: 3,
|
requests[0].request_id: 3,
|
||||||
requests[1].request_id: 1
|
requests[1].request_id: 1
|
||||||
},
|
},
|
||||||
|
scheduled_spec_decode_tokens={},
|
||||||
total_num_scheduled_tokens=4,
|
total_num_scheduled_tokens=4,
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
scheduled_spec_decode_tokens={
|
|
||||||
requests[0].request_id: [10, 11],
|
|
||||||
requests[1].request_id: []
|
|
||||||
},
|
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_input_ids=[],
|
||||||
@@ -369,11 +367,9 @@ def test_stop_via_update_from_output():
|
|||||||
scheduled_new_reqs=[],
|
scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
num_scheduled_tokens={requests[0].request_id: 3},
|
num_scheduled_tokens={requests[0].request_id: 3},
|
||||||
|
scheduled_spec_decode_tokens={},
|
||||||
total_num_scheduled_tokens=3,
|
total_num_scheduled_tokens=3,
|
||||||
scheduled_encoder_inputs={},
|
scheduled_encoder_inputs={},
|
||||||
scheduled_spec_decode_tokens={
|
|
||||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
|
||||||
},
|
|
||||||
num_common_prefix_blocks=0,
|
num_common_prefix_blocks=0,
|
||||||
finished_req_ids=set(),
|
finished_req_ids=set(),
|
||||||
free_encoder_input_ids=[],
|
free_encoder_input_ids=[],
|
||||||
|
|||||||
@@ -241,7 +241,44 @@ class AscendMLAMetadataBuilder:
|
|||||||
max_blocks] = block_tables[:num_seqs, :
|
max_blocks] = block_tables[:num_seqs, :
|
||||||
max_blocks]
|
max_blocks]
|
||||||
|
|
||||||
return graph_block_tables
|
return graph_block_tables[:num_seqs, :max_blocks]
|
||||||
|
|
||||||
|
def build_dummy(self, num_reqs: int,
|
||||||
|
num_actual_tokens: int) -> AscendMLAMetadata:
|
||||||
|
device = self.runner.device
|
||||||
|
_, max_blocks = self.runner.graph_block_tables.shape
|
||||||
|
block_table = torch.zeros((num_reqs, max_blocks),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
block_table = self._get_graph_runner_block_tables(
|
||||||
|
num_reqs, block_table)
|
||||||
|
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
|
||||||
|
input_positions = torch.zeros(num_reqs,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device).long()
|
||||||
|
slot_mapping = torch.full((num_reqs, ),
|
||||||
|
PAD_SLOT_ID,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
decode_metadata = AscendMLADecodeMetadata(
|
||||||
|
input_positions=input_positions,
|
||||||
|
block_table=block_table,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
seq_lens_list=seq_lens.tolist(),
|
||||||
|
max_seq_lens=1)
|
||||||
|
return self.metadata_cls( # type: ignore
|
||||||
|
num_input_tokens=num_actual_tokens,
|
||||||
|
num_actual_tokens=num_actual_tokens,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
head_dim=self.runner.model_config.get_head_size(),
|
||||||
|
num_decodes=1,
|
||||||
|
num_decode_tokens=1,
|
||||||
|
num_prefills=0,
|
||||||
|
attn_mask=self.runner.attn_mask,
|
||||||
|
attn_state=AscendAttentionState.DecodeOnly,
|
||||||
|
prefill=None,
|
||||||
|
decode=decode_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
num_reqs: int,
|
num_reqs: int,
|
||||||
@@ -324,7 +361,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
block_table = torch.cat([block_table, block_table_padding],
|
block_table = torch.cat([block_table, block_table_padding],
|
||||||
dim=0)
|
dim=0)
|
||||||
block_table = self._get_graph_runner_block_tables(
|
block_table = self._get_graph_runner_block_tables(
|
||||||
num_seqs, block_table)
|
num_seqs + graph_pad_size, block_table)
|
||||||
padding_0 = torch.zeros(graph_pad_size,
|
padding_0 = torch.zeros(graph_pad_size,
|
||||||
dtype=input_positions.dtype,
|
dtype=input_positions.dtype,
|
||||||
device=input_positions.device)
|
device=input_positions.device)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Iterable, Optional, Union
|
from typing import Iterable, Union
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -23,12 +23,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
|||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||||
from vllm.v1.core.sched.scheduler import Scheduler
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
from vllm.v1.core.sched.utils import check_stop
|
from vllm.v1.engine import EngineCoreOutputs
|
||||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
|
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
|
||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
|
|
||||||
@@ -130,14 +128,15 @@ class AscendScheduler(Scheduler):
|
|||||||
|
|
||||||
assert num_new_tokens > 0
|
assert num_new_tokens > 0
|
||||||
watermark = getattr(self.scheduler_config, "watermark", 0.01)
|
watermark = getattr(self.scheduler_config, "watermark", 0.01)
|
||||||
if not self._check_watermark_for_prefill(
|
if not self._check_watermark_for_prefill(request, num_new_tokens,
|
||||||
request, num_new_tokens, computed_blocks, watermark):
|
computed_blocks.blocks,
|
||||||
|
watermark):
|
||||||
# Scheduling would exceed watermark, skip.
|
# Scheduling would exceed watermark, skip.
|
||||||
skip_cur_request()
|
skip_cur_request()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||||
request, num_new_tokens, computed_blocks)
|
request, num_new_tokens, new_computed_blocks=computed_blocks)
|
||||||
if new_blocks is None:
|
if new_blocks is None:
|
||||||
# The request cannot be scheduled.
|
# The request cannot be scheduled.
|
||||||
break
|
break
|
||||||
@@ -155,9 +154,8 @@ class AscendScheduler(Scheduler):
|
|||||||
|
|
||||||
if self.lora_config and request.lora_request:
|
if self.lora_config and request.lora_request:
|
||||||
scheduled_loras.add(request.lora_request.lora_int_id)
|
scheduled_loras.add(request.lora_request.lora_int_id)
|
||||||
req_to_new_block_ids[request.request_id] = [
|
req_to_new_block_ids[request.request_id] = (
|
||||||
b.block_id for b in computed_blocks + new_blocks
|
self.kv_cache_manager.get_block_ids(request.request_id))
|
||||||
]
|
|
||||||
# Update request info.
|
# Update request info.
|
||||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||||
token_budget -= num_new_tokens
|
token_budget -= num_new_tokens
|
||||||
@@ -215,9 +213,8 @@ class AscendScheduler(Scheduler):
|
|||||||
# Schedule the request.
|
# Schedule the request.
|
||||||
scheduled_running_reqs.append(request)
|
scheduled_running_reqs.append(request)
|
||||||
self.scheduled_req_ids.add(request.request_id)
|
self.scheduled_req_ids.add(request.request_id)
|
||||||
req_to_new_block_ids[request.request_id] = [
|
req_to_new_block_ids[request.request_id] = (
|
||||||
b.block_id for b in new_blocks
|
new_blocks.get_block_ids())
|
||||||
]
|
|
||||||
num_scheduled_tokens[request.request_id] = num_new_tokens
|
num_scheduled_tokens[request.request_id] = num_new_tokens
|
||||||
token_budget -= num_new_tokens
|
token_budget -= num_new_tokens
|
||||||
req_index += 1
|
req_index += 1
|
||||||
@@ -326,7 +323,8 @@ class AscendScheduler(Scheduler):
|
|||||||
len(computed_blocks) * self.block_size)
|
len(computed_blocks) * self.block_size)
|
||||||
num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
|
num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
|
||||||
self.block_size)
|
self.block_size)
|
||||||
req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
|
req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[
|
||||||
|
request.request_id]
|
||||||
num_new_blocks = (num_required_blocks - len(req_blocks) -
|
num_new_blocks = (num_required_blocks - len(req_blocks) -
|
||||||
len(computed_blocks))
|
len(computed_blocks))
|
||||||
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
|
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
|
||||||
@@ -365,41 +363,22 @@ class AscendScheduler(Scheduler):
|
|||||||
For example, the API server can abort a request when the client
|
For example, the API server can abort a request when the client
|
||||||
disconnects.
|
disconnects.
|
||||||
"""
|
"""
|
||||||
assert RequestStatus.is_finished(finished_status)
|
|
||||||
if isinstance(request_ids, str):
|
|
||||||
request_ids = (request_ids, )
|
|
||||||
else:
|
|
||||||
request_ids = set(request_ids)
|
|
||||||
|
|
||||||
for req_id in request_ids:
|
for req_id in request_ids:
|
||||||
request = self.requests.get(req_id)
|
request = self.requests.get(req_id)
|
||||||
if request is None:
|
if request is None:
|
||||||
# Invalid request ID.
|
# Invalid request ID.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if request.status == RequestStatus.RUNNING:
|
if request.status == RequestStatus.RUNNING:
|
||||||
self.running.remove(request)
|
|
||||||
self.scheduled_req_ids.discard(request.request_id)
|
self.scheduled_req_ids.discard(request.request_id)
|
||||||
else:
|
super().finish_requests(request_ids, finished_status)
|
||||||
self.waiting.remove(request)
|
|
||||||
request.status = finished_status
|
|
||||||
self._free_request(request)
|
|
||||||
|
|
||||||
def update_from_output(
|
def update_from_output(
|
||||||
self,
|
self,
|
||||||
scheduler_output: SchedulerOutput,
|
scheduler_output: SchedulerOutput,
|
||||||
model_runner_output: ModelRunnerOutput,
|
model_runner_output: ModelRunnerOutput,
|
||||||
) -> EngineCoreOutputs:
|
) -> EngineCoreOutputs:
|
||||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
|
||||||
spec_token_ids = model_runner_output.spec_token_ids
|
|
||||||
logprobs = model_runner_output.logprobs
|
|
||||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
|
|
||||||
new_running: list[Request] = []
|
|
||||||
outputs: list[EngineCoreOutput] = []
|
|
||||||
spec_decoding_stats: Optional[SpecDecodingStats] = None
|
|
||||||
|
|
||||||
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
|
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
|
||||||
# loop can be a performance bottleneck. We should do our best to avoid
|
# loop can be a performance bottleneck. We should do our best to avoid
|
||||||
# expensive operations inside the loop.
|
# expensive operations inside the loop.
|
||||||
@@ -408,121 +387,8 @@ class AscendScheduler(Scheduler):
|
|||||||
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
|
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
|
||||||
if num_tokens_scheduled == 0:
|
if num_tokens_scheduled == 0:
|
||||||
# The request was not scheduled in this step.
|
# The request was not scheduled in this step.
|
||||||
new_running.append(request)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
req_index = model_runner_output.req_id_to_index[req_id]
|
|
||||||
generated_token_ids = sampled_token_ids[req_index]
|
|
||||||
|
|
||||||
scheduled_spec_token_ids = (
|
|
||||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
|
||||||
if scheduled_spec_token_ids:
|
|
||||||
# num_computed_tokens represents the number of tokens
|
|
||||||
# processed in the current step, considering scheduled
|
|
||||||
# tokens and rejections. If some tokens are rejected,
|
|
||||||
# num_computed_tokens is decreased by the number of rejected
|
|
||||||
# tokens, where is given by:
|
|
||||||
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
|
|
||||||
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
|
|
||||||
len(generated_token_ids))
|
|
||||||
request.num_computed_tokens -= num_tokens_rejected
|
|
||||||
spec_decoding_stats = self.make_spec_decoding_stats(
|
|
||||||
spec_decoding_stats,
|
|
||||||
num_draft_tokens=len(scheduled_spec_token_ids),
|
|
||||||
num_accepted_tokens=len(generated_token_ids) - 1)
|
|
||||||
|
|
||||||
cached_encoder_input_ids = (
|
|
||||||
self.encoder_cache_manager.get_cached_input_ids(request))
|
|
||||||
# OPTIMIZATION: Avoid list(set) if the set is empty.
|
|
||||||
if cached_encoder_input_ids:
|
|
||||||
for input_id in list(cached_encoder_input_ids):
|
|
||||||
mm_positions = request.mm_positions[input_id]
|
|
||||||
start_pos = mm_positions.offset
|
|
||||||
num_tokens = mm_positions.length
|
|
||||||
if start_pos + num_tokens <= request.num_computed_tokens:
|
|
||||||
# The encoder output is already processed and stored
|
|
||||||
# in the decoder's KV cache.
|
|
||||||
self.encoder_cache_manager.free_encoder_input(
|
|
||||||
request, input_id)
|
|
||||||
|
|
||||||
stopped = False
|
|
||||||
new_logprobs = None
|
|
||||||
new_token_ids = generated_token_ids
|
|
||||||
|
|
||||||
# Append generated tokens and check for stop. Note that if
|
|
||||||
# a request is still being prefilled, we expect the model runner
|
|
||||||
# to return empty token ids for the request.
|
|
||||||
for num_new, output_token_id in enumerate(new_token_ids, 1):
|
|
||||||
request.append_output_token_ids(output_token_id)
|
|
||||||
|
|
||||||
# Check for stop and update request state.
|
|
||||||
# This must be called before we make the EngineCoreOutput.
|
|
||||||
stopped = check_stop(request, self.max_model_len)
|
|
||||||
if stopped:
|
|
||||||
self._free_request(request)
|
|
||||||
del new_token_ids[num_new:] # Trim new tokens if needed.
|
|
||||||
break
|
|
||||||
|
|
||||||
# Extract sample logprobs if needed.
|
|
||||||
if request.sampling_params.logprobs is not None and logprobs:
|
|
||||||
# NOTE: once we support N tokens per step (spec decode),
|
|
||||||
# the outer lists can be of length > 1.
|
|
||||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
|
||||||
|
|
||||||
if new_token_ids and request.use_structured_output:
|
|
||||||
# NOTE: structured_output_request
|
|
||||||
# should not be None if use_structured_output, we have
|
|
||||||
# check above, so safe to ignore type warning
|
|
||||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
|
||||||
req_id, new_token_ids)
|
|
||||||
|
|
||||||
# Add newly generated spec token ids to the request.
|
|
||||||
if spec_token_ids is not None:
|
|
||||||
if request.use_structured_output:
|
|
||||||
metadata = request.structured_output_request
|
|
||||||
assert metadata is not None and metadata.grammar is not None
|
|
||||||
# Needs to happen after new_token_ids are accepted.
|
|
||||||
request.spec_token_ids = metadata.grammar.validate_tokens(
|
|
||||||
spec_token_ids[req_index])
|
|
||||||
else:
|
|
||||||
request.spec_token_ids = spec_token_ids[req_index]
|
|
||||||
|
|
||||||
# Get prompt logprobs for this request.
|
|
||||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
|
||||||
if new_token_ids:
|
|
||||||
# Add EngineCoreOutput for this Request.
|
|
||||||
outputs.append(
|
|
||||||
EngineCoreOutput(
|
|
||||||
request_id=req_id,
|
|
||||||
new_token_ids=new_token_ids,
|
|
||||||
finish_reason=request.get_finished_reason(),
|
|
||||||
new_logprobs=new_logprobs,
|
|
||||||
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
|
||||||
stop_reason=request.stop_reason,
|
|
||||||
events=request.take_events()))
|
|
||||||
else:
|
|
||||||
# Invariant: EngineCore returns no partial prefill outputs.
|
|
||||||
assert not prompt_logprobs_tensors
|
|
||||||
|
|
||||||
self.scheduled_req_ids.remove(req_id)
|
self.scheduled_req_ids.remove(req_id)
|
||||||
if not stopped:
|
|
||||||
new_running.append(request)
|
|
||||||
|
|
||||||
# Return the cached request data to the queue so they can be reused.
|
return super().update_from_output(scheduler_output,
|
||||||
for req_data in scheduler_output.scheduled_cached_reqs:
|
model_runner_output)
|
||||||
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
|
|
||||||
# to _cached_reqs_data will cause a memory leak.
|
|
||||||
if req_data.req_id not in self.finished_req_ids:
|
|
||||||
self._cached_reqs_data[req_data.req_id].append(req_data)
|
|
||||||
|
|
||||||
self.running = new_running
|
|
||||||
engine_core_outputs = EngineCoreOutputs(
|
|
||||||
outputs=outputs,
|
|
||||||
scheduler_stats=self.make_stats(spec_decoding_stats),
|
|
||||||
)
|
|
||||||
if self.include_finished_set:
|
|
||||||
#TODO currently sending duplicates here, improve this
|
|
||||||
engine_core_outputs.finished_requests = (
|
|
||||||
scheduler_output.finished_req_ids | self.finished_req_ids)
|
|
||||||
|
|
||||||
return engine_core_outputs
|
|
||||||
|
|||||||
@@ -66,6 +66,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
lambda: os.getenv("C_COMPILER", None),
|
lambda: os.getenv("C_COMPILER", None),
|
||||||
"VLLM_VERSION":
|
"VLLM_VERSION":
|
||||||
lambda: os.getenv("VLLM_VERSION", None),
|
lambda: os.getenv("VLLM_VERSION", None),
|
||||||
|
"VLLM_ASCEND_TRACE_RECOMPILES":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
|||||||
@@ -36,9 +36,10 @@ from transformers import PretrainedConfig
|
|||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||||
get_current_vllm_config)
|
get_current_vllm_config)
|
||||||
from vllm.distributed import (get_dp_group, get_pp_group,
|
from vllm.distributed import (get_pp_group,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
get_tp_group, tensor_model_parallel_all_reduce)
|
get_tp_group, tensor_model_parallel_all_reduce)
|
||||||
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
@@ -211,8 +212,12 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
self.tp_group = get_tp_group().device_group
|
self.tp_group = get_tp_group().device_group
|
||||||
self.tp_rank = get_tp_group().rank_in_group
|
self.tp_rank = get_tp_group().rank_in_group
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(
|
||||||
attn_metadata = get_forward_context().attn_metadata
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||||
|
if attn_metadata is None:
|
||||||
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
# when profile runs, force experts to load balanced tokens
|
# when profile runs, force experts to load balanced tokens
|
||||||
# to avoid high memory consumption on a single rank.
|
# to avoid high memory consumption on a single rank.
|
||||||
# TODO: need a better flag to indicate whether in profile run or not.
|
# TODO: need a better flag to indicate whether in profile run or not.
|
||||||
@@ -547,7 +552,11 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(
|
||||||
hidden_states, residual)
|
hidden_states, residual)
|
||||||
hidden_states = self.mlp(hidden_states)
|
|
||||||
|
if isinstance(self.mlp, CustomDeepseekV2MoE):
|
||||||
|
hidden_states = self.mlp(hidden_states, attn_metadata)
|
||||||
|
else:
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
self.mlp,
|
self.mlp,
|
||||||
|
|||||||
@@ -28,10 +28,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
|
import torch._dynamo.cache_size
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.attention import AttentionType, get_attn_backend
|
from vllm.attention import AttentionType, get_attn_backend
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import CompilationLevel, VllmConfig
|
from vllm.config import CompilationLevel, VllmConfig
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.inputs import INPUT_REGISTRY
|
from vllm.inputs import INPUT_REGISTRY
|
||||||
@@ -70,7 +72,9 @@ if TYPE_CHECKING:
|
|||||||
else:
|
else:
|
||||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs_vllm
|
||||||
|
|
||||||
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -321,6 +325,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
self.enable_torchair_graph_mode = False
|
self.enable_torchair_graph_mode = False
|
||||||
self.use_cached_npu_graph = False
|
self.use_cached_npu_graph = False
|
||||||
|
self.torchair_graph_batch_sizes = []
|
||||||
additional_config = vllm_config.additional_config
|
additional_config = vllm_config.additional_config
|
||||||
if additional_config:
|
if additional_config:
|
||||||
self.enable_torchair_graph_mode = additional_config.get(
|
self.enable_torchair_graph_mode = additional_config.get(
|
||||||
@@ -328,6 +333,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
False) and self.vllm_config.model_config.use_mla
|
False) and self.vllm_config.model_config.use_mla
|
||||||
self.use_cached_npu_graph = additional_config.get(
|
self.use_cached_npu_graph = additional_config.get(
|
||||||
"use_cached_npu_graph", False)
|
"use_cached_npu_graph", False)
|
||||||
|
self.torchair_graph_batch_sizes = additional_config.get(
|
||||||
|
"torchair_graph_batch_sizes", [])
|
||||||
|
if not isinstance(self.torchair_graph_batch_sizes, list):
|
||||||
|
logger.warning("torchair_graph_batch_sizes must be list[int]")
|
||||||
|
self.torchair_graph_batch_sizes = []
|
||||||
|
if len(self.torchair_graph_batch_sizes
|
||||||
|
) == 0 and additional_config.get(
|
||||||
|
"torchair_graph_batch_sizes_init", False):
|
||||||
|
self.init_torchair_graph_batch_sizes()
|
||||||
|
|
||||||
|
if len(self.torchair_graph_batch_sizes) == 0:
|
||||||
|
#If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
|
||||||
|
if envs_ascend.VLLM_ENABLE_MC2:
|
||||||
|
self.torchair_graph_batch_sizes = [
|
||||||
|
self.scheduler_config.max_num_seqs
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
self.torchair_graph_batch_sizes = [
|
||||||
|
1, self.scheduler_config.max_num_seqs
|
||||||
|
]
|
||||||
|
|
||||||
|
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||||
|
self.torchair_graph_batch_sizes)
|
||||||
|
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||||
|
torch._logging.set_logs(
|
||||||
|
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""Update the cached states and the persistent batch with the scheduler
|
"""Update the cached states and the persistent batch with the scheduler
|
||||||
@@ -618,7 +649,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
query_start_loc=query_start_loc, seq_lens=seq_lens)
|
||||||
# Add graph_pad_size here
|
# Add graph_pad_size here
|
||||||
if self.enable_torchair_graph_mode:
|
if self.enable_torchair_graph_mode:
|
||||||
graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
|
batchsize = len(seq_lens)
|
||||||
|
padded_batch_size = self.select_torchair_padded_batchsize(
|
||||||
|
batchsize)
|
||||||
|
graph_pad_size = padded_batch_size - batchsize
|
||||||
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
|
||||||
|
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla:
|
||||||
@@ -653,11 +687,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
input_ids = self.input_ids[:num_input_tokens]
|
input_ids = self.input_ids[:num_input_tokens]
|
||||||
|
|
||||||
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
padding = torch.zeros(graph_pad_size,
|
input_ids = self.input_ids[:padded_batch_size]
|
||||||
dtype=input_ids.dtype,
|
positions = self.positions[:padded_batch_size]
|
||||||
device=input_ids.device)
|
|
||||||
input_ids = torch.cat([input_ids, padding])
|
|
||||||
positions = torch.cat([positions, padding])
|
|
||||||
|
|
||||||
# Run forward pass
|
# Run forward pass
|
||||||
with set_forward_context(attn_metadata,
|
with set_forward_context(attn_metadata,
|
||||||
@@ -668,15 +699,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
model_kwargs["kv_caches"] = self.kv_caches
|
model_kwargs["kv_caches"] = self.kv_caches
|
||||||
model_kwargs["attn_metadata"] = attn_metadata
|
model_kwargs["attn_metadata"] = attn_metadata
|
||||||
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||||
torch._dynamo.mark_static(input_ids)
|
|
||||||
torch._dynamo.mark_static(positions)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
|
||||||
for kv in self.kv_caches:
|
|
||||||
if isinstance(kv, tuple):
|
|
||||||
torch._dynamo.mark_static(kv[0])
|
|
||||||
torch._dynamo.mark_static(kv[1])
|
|
||||||
hidden_states = self.compile_model(
|
hidden_states = self.compile_model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@@ -1068,7 +1090,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
|
def _dummy_run(
|
||||||
|
self,
|
||||||
|
num_tokens: int,
|
||||||
|
is_compile: bool = False,
|
||||||
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill,
|
||||||
|
) -> torch.Tensor:
|
||||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||||
# for dummy run with LoRA so that the num_reqs collectively
|
# for dummy run with LoRA so that the num_reqs collectively
|
||||||
# has num_tokens in total.
|
# has num_tokens in total.
|
||||||
@@ -1112,12 +1139,38 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
})
|
})
|
||||||
|
|
||||||
with set_forward_context(None, self.vllm_config):
|
with set_forward_context(None, self.vllm_config):
|
||||||
hidden_states = model(
|
if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
|
||||||
input_ids=input_ids,
|
attn_metadata = self.attn_metadata_builder.build_dummy(
|
||||||
positions=positions,
|
num_reqs=num_tokens, num_actual_tokens=1)
|
||||||
intermediate_tensors=intermediate_tensors,
|
# Only mark static while compiling
|
||||||
inputs_embeds=inputs_embeds)
|
if is_compile:
|
||||||
return hidden_states
|
torch._dynamo.mark_static(input_ids)
|
||||||
|
torch._dynamo.mark_static(positions)
|
||||||
|
torch._dynamo.mark_static(
|
||||||
|
attn_metadata.decode.block_table)
|
||||||
|
torch._dynamo.mark_static(
|
||||||
|
attn_metadata.decode.input_positions)
|
||||||
|
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
||||||
|
for kv in self.kv_caches:
|
||||||
|
assert isinstance(
|
||||||
|
kv, tuple), "kv_cache must be a tuple"
|
||||||
|
torch._dynamo.mark_static(kv[0])
|
||||||
|
torch._dynamo.mark_static(kv[1])
|
||||||
|
hidden_states = self.compile_model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=None,
|
||||||
|
kv_caches=self.kv_caches,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def profile_run(self) -> None:
|
def profile_run(self) -> None:
|
||||||
# Profile with multimodal encoder & encoder cache.
|
# Profile with multimodal encoder & encoder cache.
|
||||||
@@ -1192,13 +1245,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.compile_model = torch.compile(
|
self.compile_model = torch.compile(
|
||||||
self.model,
|
self.model,
|
||||||
dynamic=True,
|
dynamic=True,
|
||||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
backend=npu_backend)
|
backend=npu_backend)
|
||||||
else:
|
else:
|
||||||
self.compile_model = torchair.inference.cache_compile(
|
self.compile_model = torchair.inference.cache_compile(
|
||||||
self.model.forward,
|
self.model.forward,
|
||||||
dynamic=True,
|
dynamic=True,
|
||||||
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
|
||||||
config=config,
|
config=config,
|
||||||
ge_cache=False)
|
ge_cache=False)
|
||||||
|
|
||||||
@@ -1316,25 +1369,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return kv_cache_spec
|
return kv_cache_spec
|
||||||
|
|
||||||
def capture_model(self) -> None:
|
def capture_model(self) -> None:
|
||||||
if not self.use_aclgraph:
|
|
||||||
logger.warning(
|
|
||||||
"Skipping NPU graph capture. Please add "
|
|
||||||
"-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
|
|
||||||
return
|
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
start_free_npu_memory = torch.npu.mem_get_info()[0]
|
start_free_npu_memory = torch.npu.mem_get_info()[0]
|
||||||
|
# TODO(NeverRaR): Calling graph_capture(device=self.device) in
|
||||||
# Trigger ACL graph capture for specific shapes.
|
# torchair graph capture can cause some issues, so now we just
|
||||||
# Capture the large shapes first so that the smaller shapes
|
# temporarily split the codepath for the two different graph patterns.
|
||||||
# can reuse the memory pool allocated for the large shapes.
|
if self.enable_torchair_graph_mode:
|
||||||
with graph_capture(device=self.device):
|
torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
|
||||||
for num_tokens in reversed(self.aclgraph_batch_sizes):
|
graph_num = len(torchair_graph_batch_sizes)
|
||||||
|
logger.info(
|
||||||
|
"Capturing torchair graph, this usually takes %.1f~%.1f mins.",
|
||||||
|
0.5 * graph_num, 1.5 * graph_num)
|
||||||
|
attn_state = AscendAttentionState.DecodeOnly
|
||||||
|
# Trigger torchair graph capture for specific shapes.
|
||||||
|
# Capture the large shapes first so that the smaller shapes
|
||||||
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
|
for idx, num_tokens in enumerate(
|
||||||
|
reversed(torchair_graph_batch_sizes)):
|
||||||
for _ in range(self.vllm_config.compilation_config.
|
for _ in range(self.vllm_config.compilation_config.
|
||||||
cudagraph_num_of_warmups):
|
cudagraph_num_of_warmups):
|
||||||
|
self._dummy_run(num_tokens,
|
||||||
|
is_compile=True,
|
||||||
|
attn_state=attn_state)
|
||||||
|
self._dummy_run(num_tokens,
|
||||||
|
is_compile=True,
|
||||||
|
attn_state=attn_state)
|
||||||
|
logger.info("Batchsize %d is compiled successfully: %d/%d.",
|
||||||
|
num_tokens, idx + 1, graph_num)
|
||||||
|
elif self.use_aclgraph:
|
||||||
|
# Trigger ACL graph capture for specific shapes.
|
||||||
|
# Capture the large shapes first so that the smaller shapes
|
||||||
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
|
with graph_capture(device=self.device):
|
||||||
|
for num_tokens in reversed(self.aclgraph_batch_sizes):
|
||||||
|
for _ in range(self.vllm_config.compilation_config.
|
||||||
|
cudagraph_num_of_warmups):
|
||||||
|
self._dummy_run(num_tokens)
|
||||||
self._dummy_run(num_tokens)
|
self._dummy_run(num_tokens)
|
||||||
self._dummy_run(num_tokens)
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Skipping NPU graph capture. Please add -O %s to use ACL graphs. "
|
||||||
|
"Or add --additional_config={'enable_graph_mode': True} to use torchair graphs",
|
||||||
|
CompilationLevel.PIECEWISE)
|
||||||
|
return
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
end_free_npu_memory = torch.npu.mem_get_info()[0]
|
end_free_npu_memory = torch.npu.mem_get_info()[0]
|
||||||
elapsed_time = end_time - start_time
|
elapsed_time = end_time - start_time
|
||||||
@@ -1443,4 +1520,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
)
|
)
|
||||||
spec_token_ids = draft_token_ids.tolist()
|
spec_token_ids = draft_token_ids.tolist()
|
||||||
return spec_token_ids
|
return spec_token_ids
|
||||||
|
|
||||||
|
def init_torchair_graph_batch_sizes(self):
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
batch_size_step = 8
|
||||||
|
largest_batch_size = 1
|
||||||
|
|
||||||
|
if envs_ascend.VLLM_ENABLE_MC2:
|
||||||
|
batch_size_step = max(batch_size_step, tp_size)
|
||||||
|
largest_batch_size = batch_size_step
|
||||||
|
while (largest_batch_size < 8):
|
||||||
|
self.torchair_graph_batch_sizes.append(largest_batch_size)
|
||||||
|
largest_batch_size *= 2
|
||||||
|
|
||||||
|
while (largest_batch_size <= self.scheduler_config.max_num_seqs):
|
||||||
|
self.torchair_graph_batch_sizes.append(largest_batch_size)
|
||||||
|
largest_batch_size += batch_size_step
|
||||||
|
|
||||||
|
def select_torchair_padded_batchsize(self, batchsize: int):
|
||||||
|
selected_batchsize = self.max_num_reqs
|
||||||
|
for padded_batchsize in self.torchair_graph_batch_sizes:
|
||||||
|
if batchsize <= padded_batchsize < selected_batchsize:
|
||||||
|
selected_batchsize = padded_batchsize
|
||||||
|
return selected_batchsize
|
||||||
|
|||||||
Reference in New Issue
Block a user