feat: support compile torchair graph while warming up (#839)

### What this PR does / why we need it?
Support compiling the torchair graph while warming up: `capture_model` now pre-compiles decode-only torchair graphs for a configurable list of padded batch sizes, instead of relying on the first real decode request to trigger compilation.

Signed-off-by: boying <897013703@qq.com>
Author: NeverRaR
Date: 2025-05-31 06:03:03 +08:00
Committed by: GitHub
Commit: 507ae627ca (parent: d9fb027068)
7 changed files with 242 additions and 234 deletions

View File

@@ -108,8 +108,7 @@ jobs:
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-            # AscendScheduler doesn't work, fix it later
-            # pytest -sv tests/singlecard/tets_schedule.py
+            pytest -sv tests/singlecard/test_scheduler.py
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
             pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py

@@ -124,8 +123,7 @@ jobs:
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-            # AscendScheduler doesn't work, fix it later
-            # pytest -sv tests/singlecard/tets_schedule.py
+            pytest -sv tests/singlecard/test_scheduler.py
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
             pytest -sv tests/singlecard/ --ignore=tests/singlecard/test_offline_inference.py --ignore=tests/singlecard/test_scheduler.py --ignore=tests/singlecard/test_guided_decoding.py

View File

@@ -31,6 +31,7 @@ from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm_ascend.core.scheduler import AscendScheduler
+from vllm_ascend.utils import vllm_version_is

 EOS_TOKEN_ID = 50256

@@ -83,11 +84,10 @@ def create_scheduler(
         cache_dtype="auto",
         **kwargs_cache,
     )
-    vllm_config = VllmConfig(
-        scheduler_config=scheduler_config,
-        model_config=model_config,
-        cache_config=cache_config,
-    )
+    vllm_config = VllmConfig(scheduler_config=scheduler_config,
+                             model_config=model_config,
+                             cache_config=cache_config)
     kv_cache_config = KVCacheConfig(
         num_blocks=10000,  # A large number of blocks to hold all requests
         tensors={},

@@ -98,10 +98,7 @@ def create_scheduler(
     )
     cache_config.num_gpu_blocks = 10000
     return AscendScheduler(
-        scheduler_config,
-        model_config,
-        cache_config,
-        lora_config=None,
+        vllm_config,
         kv_cache_config=kv_cache_config,
         log_stats=True,
         structured_output_manager=StructuredOutputManager(vllm_config),

@@ -126,17 +123,27 @@ def create_requests(num_requests: int,
         else:
             mm_position = None
             mm_inputs = None
-        request = Request(
-            request_id=f"{i}",
-            prompt=None,
-            prompt_token_ids=[i] * num_tokens,
-            sampling_params=sampling_params,
-            multi_modal_inputs=mm_inputs,
-            multi_modal_placeholders=mm_position,
-            multi_modal_hashes=None,
-            eos_token_id=EOS_TOKEN_ID,
-            arrival_time=0,
-        )
+        if vllm_version_is("0.9.0"):
+            request = Request(
+                request_id=f"{i}",
+                prompt_token_ids=[i] * num_tokens,
+                sampling_params=sampling_params,
+                multi_modal_inputs=mm_inputs,
+                multi_modal_placeholders=mm_position,
+                multi_modal_hashes=None,
+                arrival_time=0,
+                eos_token_id=EOS_TOKEN_ID,
+            )
+        else:
+            request = Request(
+                request_id=f"{i}",
+                prompt_token_ids=[i] * num_tokens,
+                sampling_params=sampling_params,
+                multi_modal_inputs=mm_inputs,
+                multi_modal_placeholders=mm_position,
+                multi_modal_hashes=None,
+                eos_token_id=EOS_TOKEN_ID,
+            )
         requests.append(request)
     return requests
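The version gate above exists because the `Request` constructor differs between vLLM 0.9.0 and later releases (0.9.0 still takes an explicit `arrival_time`). A minimal sketch of the same pattern, using a hypothetical `make_request` helper that is not part of this PR:

```python
# Hypothetical helper (not in this PR) showing the version-gate pattern.
from vllm.v1.request import Request
from vllm_ascend.utils import vllm_version_is


def make_request(i: int, num_tokens: int, sampling_params, eos_token_id: int):
    kwargs = dict(
        request_id=f"{i}",
        prompt_token_ids=[i] * num_tokens,
        sampling_params=sampling_params,
        multi_modal_inputs=None,
        multi_modal_placeholders=None,
        multi_modal_hashes=None,
        eos_token_id=eos_token_id,
    )
    if vllm_version_is("0.9.0"):
        kwargs["arrival_time"] = 0  # only the 0.9.0 constructor takes this
    return Request(**kwargs)
```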
@@ -225,12 +232,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 1,
             requests[1].request_id: 2
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [],
-            requests[1].request_id: [10]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],

@@ -275,12 +279,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 3,
             requests[1].request_id: 2
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=5,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [10, 42],
-            requests[1].request_id: [13]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],

@@ -323,12 +324,9 @@ def test_stop_via_update_from_output():
             requests[0].request_id: 3,
             requests[1].request_id: 1
         },
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=4,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [10, 11],
-            requests[1].request_id: []
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],

@@ -369,11 +367,9 @@ def test_stop_via_update_from_output():
         scheduled_new_reqs=[],
         scheduled_cached_reqs=[],
         num_scheduled_tokens={requests[0].request_id: 3},
+        scheduled_spec_decode_tokens={},
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
-        scheduled_spec_decode_tokens={
-            requests[0].request_id: [EOS_TOKEN_ID, 10]
-        },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],

View File

@@ -241,7 +241,44 @@ class AscendMLAMetadataBuilder:
                               max_blocks] = block_tables[:num_seqs, :
                                                          max_blocks]

-        return graph_block_tables
+        return graph_block_tables[:num_seqs, :max_blocks]
+
+    def build_dummy(self, num_reqs: int,
+                    num_actual_tokens: int) -> AscendMLAMetadata:
+        device = self.runner.device
+        _, max_blocks = self.runner.graph_block_tables.shape
+        block_table = torch.zeros((num_reqs, max_blocks),
+                                  dtype=torch.int32,
+                                  device=device)
+        block_table = self._get_graph_runner_block_tables(
+            num_reqs, block_table)
+        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        input_positions = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device).long()
+        slot_mapping = torch.full((num_reqs, ),
+                                  PAD_SLOT_ID,
+                                  dtype=torch.int32,
+                                  device=device)
+        decode_metadata = AscendMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            seq_lens_list=seq_lens.tolist(),
+            max_seq_lens=1)
+        return self.metadata_cls(  # type: ignore
+            num_input_tokens=num_actual_tokens,
+            num_actual_tokens=num_actual_tokens,
+            slot_mapping=slot_mapping,
+            head_dim=self.runner.model_config.get_head_size(),
+            num_decodes=1,
+            num_decode_tokens=1,
+            num_prefills=0,
+            attn_mask=self.runner.attn_mask,
+            attn_state=AscendAttentionState.DecodeOnly,
+            prefill=None,
+            decode=decode_metadata,
+        )
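The new `build_dummy` fabricates a decode-only batch so a torchair graph can be traced before any real request arrives: every sequence has length 1, the block table is zeroed, and every slot points at the pad slot. A standalone sketch of the tensors it produces (the `PAD_SLOT_ID` value and `max_blocks` here are assumptions standing in for runner state):

```python
import torch

PAD_SLOT_ID = -1               # assumption: pad-slot sentinel value
num_reqs, max_blocks = 4, 128  # assumption: one warm-up batch size

block_table = torch.zeros((num_reqs, max_blocks), dtype=torch.int32)
seq_lens = torch.ones(num_reqs, dtype=torch.int32)     # every seq is 1 token
input_positions = torch.zeros(num_reqs, dtype=torch.int32).long()
slot_mapping = torch.full((num_reqs, ), PAD_SLOT_ID, dtype=torch.int32)
```

Because these shapes depend only on `num_reqs`, one compiled graph per batch-size bucket is enough to serve real decode traffic later.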
     def build(self,
               num_reqs: int,

@@ -324,7 +361,7 @@ class AscendMLAMetadataBuilder:
             block_table = torch.cat([block_table, block_table_padding],
                                     dim=0)
             block_table = self._get_graph_runner_block_tables(
-                num_seqs, block_table)
+                num_seqs + graph_pad_size, block_table)
             padding_0 = torch.zeros(graph_pad_size,
                                     dtype=input_positions.dtype,
                                     device=input_positions.device)

View File

@@ -15,7 +15,7 @@
 # This file is a part of the vllm-ascend project.
 #
 from collections import deque
-from typing import Iterable, Optional, Union
+from typing import Iterable, Union

 from vllm.config import VllmConfig
 from vllm.logger import logger

@@ -23,12 +23,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.utils import cdiv
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.core.sched.utils import check_stop
-from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.engine import EngineCoreOutputs
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
-from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager

@@ -130,14 +128,15 @@ class AscendScheduler(Scheduler):
                 assert num_new_tokens > 0
                 watermark = getattr(self.scheduler_config, "watermark", 0.01)
-                if not self._check_watermark_for_prefill(
-                        request, num_new_tokens, computed_blocks, watermark):
+                if not self._check_watermark_for_prefill(request, num_new_tokens,
+                                                         computed_blocks.blocks,
+                                                         watermark):
                     # Scheduling would exceed watermark, skip.
                     skip_cur_request()
                     continue

                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens, computed_blocks)
+                    request, num_new_tokens, new_computed_blocks=computed_blocks)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     break

@@ -155,9 +154,8 @@ class AscendScheduler(Scheduler):
                 if self.lora_config and request.lora_request:
                     scheduled_loras.add(request.lora_request.lora_int_id)
-                req_to_new_block_ids[request.request_id] = [
-                    b.block_id for b in computed_blocks + new_blocks
-                ]
+                req_to_new_block_ids[request.request_id] = (
+                    self.kv_cache_manager.get_block_ids(request.request_id))
                 # Update request info.
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens

@@ -215,9 +213,8 @@ class AscendScheduler(Scheduler):
             # Schedule the request.
             scheduled_running_reqs.append(request)
             self.scheduled_req_ids.add(request.request_id)
-            req_to_new_block_ids[request.request_id] = [
-                b.block_id for b in new_blocks
-            ]
+            req_to_new_block_ids[request.request_id] = (
+                new_blocks.get_block_ids())
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1

@@ -326,7 +323,8 @@ class AscendScheduler(Scheduler):
                          len(computed_blocks) * self.block_size)
         num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
                                    self.block_size)
-        req_blocks = self.kv_cache_manager.req_to_blocks[request.request_id]
+        req_blocks = self.kv_cache_manager.single_type_manager.req_to_blocks[
+            request.request_id]
         num_new_blocks = (num_required_blocks - len(req_blocks) -
                           len(computed_blocks))
         num_evictable_computed_blocks = sum(1 for blk in computed_blocks

@@ -365,41 +363,22 @@
         For example, the API server can abort a request when the client
         disconnects.
         """
-        assert RequestStatus.is_finished(finished_status)
-        if isinstance(request_ids, str):
-            request_ids = (request_ids, )
-        else:
-            request_ids = set(request_ids)
-
         for req_id in request_ids:
             request = self.requests.get(req_id)
             if request is None:
                 # Invalid request ID.
                 continue
             if request.status == RequestStatus.RUNNING:
-                self.running.remove(request)
                 self.scheduled_req_ids.discard(request.request_id)
-            else:
-                self.waiting.remove(request)
-            request.status = finished_status
-            self._free_request(request)
+        super().finish_requests(request_ids, finished_status)

     def update_from_output(
         self,
         scheduler_output: SchedulerOutput,
         model_runner_output: ModelRunnerOutput,
     ) -> EngineCoreOutputs:
-        sampled_token_ids = model_runner_output.sampled_token_ids
-        spec_token_ids = model_runner_output.spec_token_ids
-        logprobs = model_runner_output.logprobs
-        prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
-
-        new_running: list[Request] = []
-        outputs: list[EngineCoreOutput] = []
-        spec_decoding_stats: Optional[SpecDecodingStats] = None

         # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
         # loop can be a performance bottleneck. We should do our best to avoid
         # expensive operations inside the loop.

@@ -408,121 +387,8 @@ class AscendScheduler(Scheduler):
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
             if num_tokens_scheduled == 0:
                 # The request was not scheduled in this step.
-                new_running.append(request)
                 continue
-
-            req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[req_index]
-
-            scheduled_spec_token_ids = (
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id))
-            if scheduled_spec_token_ids:
-                # num_computed_tokens represents the number of tokens
-                # processed in the current step, considering scheduled
-                # tokens and rejections. If some tokens are rejected,
-                # num_computed_tokens is decreased by the number of rejected
-                # tokens, where is given by:
-                # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
-                num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
-                                       len(generated_token_ids))
-                request.num_computed_tokens -= num_tokens_rejected
-                spec_decoding_stats = self.make_spec_decoding_stats(
-                    spec_decoding_stats,
-                    num_draft_tokens=len(scheduled_spec_token_ids),
-                    num_accepted_tokens=len(generated_token_ids) - 1)
-
-            cached_encoder_input_ids = (
-                self.encoder_cache_manager.get_cached_input_ids(request))
-            # OPTIMIZATION: Avoid list(set) if the set is empty.
-            if cached_encoder_input_ids:
-                for input_id in list(cached_encoder_input_ids):
-                    mm_positions = request.mm_positions[input_id]
-                    start_pos = mm_positions.offset
-                    num_tokens = mm_positions.length
-                    if start_pos + num_tokens <= request.num_computed_tokens:
-                        # The encoder output is already processed and stored
-                        # in the decoder's KV cache.
-                        self.encoder_cache_manager.free_encoder_input(
-                            request, input_id)
-
-            stopped = False
-            new_logprobs = None
-            new_token_ids = generated_token_ids
-
-            # Append generated tokens and check for stop. Note that if
-            # a request is still being prefilled, we expect the model runner
-            # to return empty token ids for the request.
-            for num_new, output_token_id in enumerate(new_token_ids, 1):
-                request.append_output_token_ids(output_token_id)
-
-                # Check for stop and update request state.
-                # This must be called before we make the EngineCoreOutput.
-                stopped = check_stop(request, self.max_model_len)
-                if stopped:
-                    self._free_request(request)
-                    del new_token_ids[num_new:]  # Trim new tokens if needed.
-                    break
-
-            # Extract sample logprobs if needed.
-            if request.sampling_params.logprobs is not None and logprobs:
-                # NOTE: once we support N tokens per step (spec decode),
-                # the outer lists can be of length > 1.
-                new_logprobs = logprobs.slice(req_index, req_index + 1)
-
-            if new_token_ids and request.use_structured_output:
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # check above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    req_id, new_token_ids)
-
-            # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
-                if request.use_structured_output:
-                    metadata = request.structured_output_request
-                    assert metadata is not None and metadata.grammar is not None
-                    # Needs to happen after new_token_ids are accepted.
-                    request.spec_token_ids = metadata.grammar.validate_tokens(
-                        spec_token_ids[req_index])
-                else:
-                    request.spec_token_ids = spec_token_ids[req_index]
-
-            # Get prompt logprobs for this request.
-            prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids:
-                # Add EngineCoreOutput for this Request.
-                outputs.append(
-                    EngineCoreOutput(
-                        request_id=req_id,
-                        new_token_ids=new_token_ids,
-                        finish_reason=request.get_finished_reason(),
-                        new_logprobs=new_logprobs,
-                        new_prompt_logprobs_tensors=prompt_logprobs_tensors,
-                        stop_reason=request.stop_reason,
-                        events=request.take_events()))
-            else:
-                # Invariant: EngineCore returns no partial prefill outputs.
-                assert not prompt_logprobs_tensors
-
             self.scheduled_req_ids.remove(req_id)
-            if not stopped:
-                new_running.append(request)
-
-        # Return the cached request data to the queue so they can be reused.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
-            # to _cached_reqs_data will cause a memory leak.
-            if req_data.req_id not in self.finished_req_ids:
-                self._cached_reqs_data[req_data.req_id].append(req_data)
-
-        self.running = new_running
-        engine_core_outputs = EngineCoreOutputs(
-            outputs=outputs,
-            scheduler_stats=self.make_stats(spec_decoding_stats),
-        )
-        if self.include_finished_set:
-            #TODO currently sending duplicates here, improve this
-            engine_core_outputs.finished_requests = (
-                scheduler_output.finished_req_ids | self.finished_req_ids)
-
-        return engine_core_outputs
+
+        return super().update_from_output(scheduler_output,
+                                          model_runner_output)
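The net effect of this refactor is that `AscendScheduler` no longer reimplements output handling: it clears only its own `scheduled_req_ids` bookkeeping and delegates token appends, stop checks, spec-decode accounting, and `EngineCoreOutputs` assembly to the upstream `Scheduler`. A dependency-free sketch of the delegation pattern (toy classes, not the PR's code):

```python
# Toy sketch of the delegation pattern adopted above.
class BaseScheduler:
    def update_from_output(self, scheduler_output, model_runner_output):
        return f"handled {len(scheduler_output)} reqs upstream"


class AscendLikeScheduler(BaseScheduler):
    def __init__(self):
        self.scheduled_req_ids = set()

    def update_from_output(self, scheduler_output, model_runner_output):
        # Clear only this subclass's per-step bookkeeping...
        for req_id, num_tokens in scheduler_output.items():
            if num_tokens > 0:
                self.scheduled_req_ids.discard(req_id)
        # ...then defer everything else to the base scheduler.
        return super().update_from_output(scheduler_output, model_runner_output)


s = AscendLikeScheduler()
s.scheduled_req_ids = {"a", "b"}
print(s.update_from_output({"a": 3, "b": 0}, None))  # clears "a", keeps "b"
```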

View File

@@ -66,6 +66,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
     lambda: os.getenv("C_COMPILER", None),
     "VLLM_VERSION":
     lambda: os.getenv("VLLM_VERSION", None),
+    "VLLM_ASCEND_TRACE_RECOMPILES":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
 }

 # end-env-vars-definition
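The new `VLLM_ASCEND_TRACE_RECOMPILES` flag is wired into `torch._logging` by the model runner (see below), so Dynamo prints why a compiled graph recompiled, for example an input shape that missed every static bucket. A minimal sketch of what the flag ultimately toggles:

```python
# Minimal sketch of the flag's effect: enable Dynamo recompile tracing.
import os

import torch

trace_recompiles = bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", "0")))
torch._logging.set_logs(recompiles=trace_recompiles)
```

Launching with `VLLM_ASCEND_TRACE_RECOMPILES=1` should then surface recompile reasons in the server log.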

View File

@@ -36,9 +36,10 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
-from vllm.distributed import (get_dp_group, get_pp_group,
+from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm

@@ -211,8 +212,12 @@ class CustomDeepseekV2MoE(nn.Module):
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        attn_metadata = get_forward_context().attn_metadata
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
         # when profile runs, force experts to load balanced tokens
         # to avoid high memory consumption on a single rank.
         # TODO: need a better flag to indicate whether in profile run or not.

@@ -547,7 +552,11 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp, CustomDeepseekV2MoE):
+            hidden_states = self.mlp(hidden_states, attn_metadata)
+        else:
+            hidden_states = self.mlp(hidden_states)

         if isinstance(
             self.mlp,
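Threading `attn_metadata` through `forward` matters for graph mode: a traced torchair graph should not read mutable global state such as the forward context, so the compiled caller passes metadata explicitly while eager callers keep the old behavior. A dependency-light sketch of the optional-argument convention (toy class, not the PR's code):

```python
# Toy stand-in for CustomDeepseekV2MoE's new calling convention.
from typing import Optional

import torch
import torch.nn as nn

_AMBIENT_METADATA: Optional[dict] = None  # stand-in for vLLM's forward context


class ToyMoE(nn.Module):
    def forward(self,
                hidden_states: torch.Tensor,
                attn_metadata: Optional[dict] = None) -> torch.Tensor:
        if attn_metadata is None:
            # Eager path: read ambient context, like get_forward_context().
            attn_metadata = _AMBIENT_METADATA
        # A compiled caller passes attn_metadata explicitly, so the traced
        # graph never touches mutable global state.
        return hidden_states


moe = ToyMoE()
x = torch.randn(2, 8)
moe(x)                    # eager: metadata from ambient context
moe(x, attn_metadata={})  # graph path: metadata threaded in explicitly
```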

View File

@@ -28,10 +28,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
 import numpy as np
 import numpy.typing as npt
 import torch
+import torch._dynamo.cache_size
 import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY

@@ -70,7 +72,9 @@ if TYPE_CHECKING:
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")

-import vllm.envs as envs
+import vllm.envs as envs_vllm
+
+import vllm_ascend.envs as envs_ascend


 @dataclass

@@ -321,6 +325,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.sampler = Sampler()
         self.enable_torchair_graph_mode = False
         self.use_cached_npu_graph = False
+        self.torchair_graph_batch_sizes = []
         additional_config = vllm_config.additional_config
         if additional_config:
             self.enable_torchair_graph_mode = additional_config.get(

@@ -328,6 +333,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 False) and self.vllm_config.model_config.use_mla
             self.use_cached_npu_graph = additional_config.get(
                 "use_cached_npu_graph", False)
+            self.torchair_graph_batch_sizes = additional_config.get(
+                "torchair_graph_batch_sizes", [])
+            if not isinstance(self.torchair_graph_batch_sizes, list):
+                logger.warning("torchair_graph_batch_sizes must be list[int]")
+                self.torchair_graph_batch_sizes = []
+            if len(self.torchair_graph_batch_sizes
+                   ) == 0 and additional_config.get(
+                       "torchair_graph_batch_sizes_init", False):
+                self.init_torchair_graph_batch_sizes()
+
+        if len(self.torchair_graph_batch_sizes) == 0:
+            # If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
+            if envs_ascend.VLLM_ENABLE_MC2:
+                self.torchair_graph_batch_sizes = [
+                    self.scheduler_config.max_num_seqs
+                ]
+            else:
+                self.torchair_graph_batch_sizes = [
+                    1, self.scheduler_config.max_num_seqs
+                ]
+
+        torch._dynamo.cache_size.config.cache_size_limit += len(
+            self.torchair_graph_batch_sizes)
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        torch._logging.set_logs(
+            recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -618,7 +649,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             query_start_loc=query_start_loc, seq_lens=seq_lens)

         # Add graph_pad_size here
         if self.enable_torchair_graph_mode:
-            graph_pad_size = self.scheduler_config.max_num_seqs - len(seq_lens)
+            batchsize = len(seq_lens)
+            padded_batch_size = self.select_torchair_padded_batchsize(
+                batchsize)
+            graph_pad_size = padded_batch_size - batchsize
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size

         if self.vllm_config.model_config.use_mla:

@@ -653,11 +687,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             input_ids = self.input_ids[:num_input_tokens]

         if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-            padding = torch.zeros(graph_pad_size,
-                                  dtype=input_ids.dtype,
-                                  device=input_ids.device)
-            input_ids = torch.cat([input_ids, padding])
-            positions = torch.cat([positions, padding])
+            input_ids = self.input_ids[:padded_batch_size]
+            positions = self.positions[:padded_batch_size]

         # Run forward pass
         with set_forward_context(attn_metadata,

@@ -668,15 +699,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
             if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                torch._dynamo.mark_static(input_ids)
-                torch._dynamo.mark_static(positions)
-                torch._dynamo.mark_static(attn_metadata.decode.block_table)
-                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
-                torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                for kv in self.kv_caches:
-                    if isinstance(kv, tuple):
-                        torch._dynamo.mark_static(kv[0])
-                        torch._dynamo.mark_static(kv[1])
                 hidden_states = self.compile_model(
                     input_ids=input_ids,
                     positions=positions,

@@ -1068,7 +1090,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

     @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        is_compile: bool = False,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill,
+    ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
         # has num_tokens in total.

@@ -1112,12 +1139,38 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 })

         with set_forward_context(None, self.vllm_config):
-            hidden_states = model(
-                input_ids=input_ids,
-                positions=positions,
-                intermediate_tensors=intermediate_tensors,
-                inputs_embeds=inputs_embeds)
-        return hidden_states
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+                attn_metadata = self.attn_metadata_builder.build_dummy(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                # Only mark static while compiling
+                if is_compile:
+                    torch._dynamo.mark_static(input_ids)
+                    torch._dynamo.mark_static(positions)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.block_table)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                    for kv in self.kv_caches:
+                        assert isinstance(
+                            kv, tuple), "kv_cache must be a tuple"
+                        torch._dynamo.mark_static(kv[0])
+                        torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
+            return hidden_states

     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
@@ -1192,13 +1245,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.compile_model = torch.compile(
                 self.model,
                 dynamic=True,
-                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                 backend=npu_backend)
         else:
             self.compile_model = torchair.inference.cache_compile(
                 self.model.forward,
                 dynamic=True,
-                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                 config=config,
                 ge_cache=False)

@@ -1316,25 +1369,49 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         return kv_cache_spec

     def capture_model(self) -> None:
-        if not self.use_aclgraph:
-            logger.warning(
-                "Skipping NPU graph capture. Please add "
-                "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
-            return
         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-        # Trigger ACL graph capture for specific shapes.
-        # Capture the large shapes first so that the smaller shapes
-        # can reuse the memory pool allocated for the large shapes.
-        with graph_capture(device=self.device):
-            for num_tokens in reversed(self.aclgraph_batch_sizes):
-                for _ in range(self.vllm_config.compilation_config.
-                               cudagraph_num_of_warmups):
-                    self._dummy_run(num_tokens)
-                self._dummy_run(num_tokens)
+        # TODO(NeverRaR): Calling graph_capture(device=self.device) in
+        # torchair graph capture can cause some issues, so now we just
+        # temporarily split the codepath for the two different graph patterns.
+        if self.enable_torchair_graph_mode:
+            torchair_graph_batch_sizes = self.torchair_graph_batch_sizes
+            graph_num = len(torchair_graph_batch_sizes)
+            logger.info(
+                "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
+                0.5 * graph_num, 1.5 * graph_num)
+            attn_state = AscendAttentionState.DecodeOnly
+            # Trigger torchair graph capture for specific shapes.
+            # Capture the large shapes first so that the smaller shapes
+            # can reuse the memory pool allocated for the large shapes.
+            for idx, num_tokens in enumerate(
+                    reversed(torchair_graph_batch_sizes)):
+                for _ in range(self.vllm_config.compilation_config.
+                               cudagraph_num_of_warmups):
+                    self._dummy_run(num_tokens,
+                                    is_compile=True,
+                                    attn_state=attn_state)
+                self._dummy_run(num_tokens,
+                                is_compile=True,
+                                attn_state=attn_state)
+                logger.info("Batchsize %d is compiled successfully: %d/%d.",
+                            num_tokens, idx + 1, graph_num)
+        elif self.use_aclgraph:
+            # Trigger ACL graph capture for specific shapes.
+            # Capture the large shapes first so that the smaller shapes
+            # can reuse the memory pool allocated for the large shapes.
+            with graph_capture(device=self.device):
+                for num_tokens in reversed(self.aclgraph_batch_sizes):
+                    for _ in range(self.vllm_config.compilation_config.
+                                   cudagraph_num_of_warmups):
+                        self._dummy_run(num_tokens)
+                    self._dummy_run(num_tokens)
+        else:
+            logger.warning(
+                "Skipping NPU graph capture. Please add -O %s to use ACL graphs. "
+                "Or add --additional_config={'enable_graph_mode': True} to use torchair graphs",
+                CompilationLevel.PIECEWISE)
+            return
         end_time = time.perf_counter()
         end_free_npu_memory = torch.npu.mem_get_info()[0]
         elapsed_time = end_time - start_time
@@ -1443,4 +1520,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             sampling_metadata=sampling_metadata,
         )
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
+
+    def init_torchair_graph_batch_sizes(self):
+        tp_size = get_tensor_model_parallel_world_size()
+        batch_size_step = 8
+        largest_batch_size = 1
+
+        if envs_ascend.VLLM_ENABLE_MC2:
+            batch_size_step = max(batch_size_step, tp_size)
+            largest_batch_size = batch_size_step
+        while (largest_batch_size < 8):
+            self.torchair_graph_batch_sizes.append(largest_batch_size)
+            largest_batch_size *= 2
+
+        while (largest_batch_size <= self.scheduler_config.max_num_seqs):
+            self.torchair_graph_batch_sizes.append(largest_batch_size)
+            largest_batch_size += batch_size_step
+
+    def select_torchair_padded_batchsize(self, batchsize: int):
+        selected_batchsize = self.max_num_reqs
+        for padded_batchsize in self.torchair_graph_batch_sizes:
+            if batchsize <= padded_batchsize < selected_batchsize:
+                selected_batchsize = padded_batchsize
+        return selected_batchsize
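Worked example of the two new helpers: with the defaults (`batch_size_step = 8`, MC2 off) and `max_num_seqs = 32`, `init_torchair_graph_batch_sizes` yields the buckets 1, 2, 4, 8, 16, 24, 32, and `select_torchair_padded_batchsize` rounds a live decode batch up to the nearest bucket. A standalone sketch reproducing the arithmetic:

```python
# Standalone sketch of the bucket math above (non-MC2 path only).
def init_buckets(max_num_seqs: int, step: int = 8) -> list[int]:
    buckets, size = [], 1
    while size < 8:              # powers of two below the step size
        buckets.append(size)
        size *= 2
    while size <= max_num_seqs:  # then linear steps of `step`
        buckets.append(size)
        size += step
    return buckets


def select_bucket(buckets: list[int], batchsize: int, max_num_reqs: int) -> int:
    selected = max_num_reqs
    for b in buckets:
        if batchsize <= b < selected:
            selected = b
    return selected


buckets = init_buckets(32)             # -> [1, 2, 4, 8, 16, 24, 32]
print(select_bucket(buckets, 10, 32))  # a live batch of 10 pads to 16
```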