Files
xc-llm-ascend/vllm_ascend/torchair/utils.py
Nicholas Tao 7bec1a9b9c qwen3_moe/qwen25 support torchair graph (#2403)
### What this PR does / why we need it?
Adds TorchAir graph-mode support for the qwen3_moe and qwen2.5 models.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```python
llm = LLM(
    model=model,
    tensor_parallel_size=GPUs_per_dp_rank,
    enforce_eager=False,
    enable_expert_parallel=True,
    max_model_len=4096,
    max_num_seqs=16,
    trust_remote_code=trust_remote_code,
    gpu_memory_utilization=0.4,
    additional_config={
             "torchair_graph_config": {
                 "enabled": True,
                 "use_cached_graph": False,
                 "graph_batch_sizes_init": False,
                 "graph_batch_sizes": [16]
             },
             "ascend_scheduler_config": {
                 "enabled": True,
                 "chunked_prefill_enabled":True,
             },
             "refresh": True,
    },
)
```

- vLLM version: v0.10.0
- vLLM main: b87cb97a53

Signed-off-by: taoyuxiang <oui.nicholas.tao@gmail.com>
2025-08-20 11:23:50 +08:00

153 lines
4.6 KiB
Python

import fcntl
import os
import shutil
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import Optional

import torch

try:
    # Recent release of torchair has moved these ops to `.scope`.
    from torchair.scope import npu_stream_switch as _npu_stream_switch
    from torchair.scope import npu_wait_tensor as _npu_wait_tensor
except ImportError:
    from torchair.ops import NpuStreamSwitch as _npu_stream_switch
    from torchair.ops import npu_wait_tensor as _npu_wait_tensor
# Name of the sub-directory (inside the torchair cache dir) that holds the
# per-rank KV-cache byte-count files.
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
# Base file name of each per-rank file; the full name is
# f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}".
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
# Default cache directory name, resolved against the current working
# directory at import time.
TORCHAIR_CACHE_PATH_NAME = ".torchair_cache"
# Root of all torchair cache artifacts. The TORCHAIR_CACHE_HOME environment
# variable, when set, overrides the cwd-relative default.
TORCHAIR_CACHE_DIR = os.environ.get(
    'TORCHAIR_CACHE_HOME',
    os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME),
)
@dataclass
class TorchairCommonAttentionMetadata:
    """
    Per-batch attention metadata, shared across layers and backends.

    AttentionMetadataBuilder instances use it to construct per-layer
    metadata. The mask fields are optional and default to ``None``;
    ``graph_pad_size`` defaults to the sentinel ``-1``.
    """

    # Number of requests in the batch.
    num_reqs: int
    # Total number of tokens in the batch.
    num_actual_tokens: int
    # Presumably the number of tokens decoded per request in one step
    # (e.g. > 1 with speculative decoding) -- TODO confirm against callers.
    decode_token_per_req: int
    # Per-request query sequence lengths. NOTE(review): whether these are
    # raw lengths or cumulative offsets is not visible here -- confirm.
    actual_seq_lengths_q: list[int]
    # Attention mask; None when the caller does not supply one. Annotated
    # Optional because None is the default value.
    attn_mask: Optional[torch.Tensor] = None
    # Attention mask for speculative decoding (inferred from the name --
    # TODO confirm); None when not supplied.
    spec_attn_mask: Optional[torch.Tensor] = None
    # Graph padding size; -1 is the default sentinel, presumably meaning
    # "not set" -- TODO confirm.
    graph_pad_size: int = -1
@contextmanager
def _file_lock(file_descriptor, lock_type):
fcntl.flock(file_descriptor, lock_type)
try:
yield
finally:
fcntl.flock(file_descriptor, fcntl.LOCK_UN)
def _get_torchair_current_work_dir(file_name=None):
    """Return ``TORCHAIR_CACHE_DIR``, or *file_name* joined onto it if given."""
    base_dir = TORCHAIR_CACHE_DIR
    return base_dir if file_name is None else os.path.join(base_dir, file_name)
def check_torchair_cache_exist():
    """Return True when the torchair cache directory exists and is non-empty.

    Returns False if the path does not exist or ``os.listdir`` finds no
    entries in it.
    """
    cache_root = _get_torchair_current_work_dir()
    if not os.path.exists(cache_root):
        return False
    return len(os.listdir(cache_root)) > 0
def check_kv_cache_bytes_cache_exist():
    """Return True when the KV-cache-bytes cache directory has any entries.

    The directory is ``KV_CACHE_BYTES_CACHE_PATH_NAME`` inside the torchair
    cache dir; a missing or empty directory yields False.
    """
    bytes_dir = _get_torchair_current_work_dir(KV_CACHE_BYTES_CACHE_PATH_NAME)
    if not os.path.exists(bytes_dir):
        return False
    return len(os.listdir(bytes_dir)) > 0
def read_kv_cache_bytes_from_file(rank) -> int:
    """Return the KV-cache byte count previously stored for *rank*.

    Reads the first line of ``f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}"``
    in the KV-cache-bytes cache directory while holding a shared ``flock``,
    and parses it with ``int()``.

    Raises:
        FileNotFoundError: if no file exists for *rank* (raised by ``open``).
        ValueError: if the first line does not parse as an integer.
    """
    record_path = os.path.join(
        _get_torchair_current_work_dir(KV_CACHE_BYTES_CACHE_PATH_NAME),
        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}",
    )
    with open(record_path, "r", encoding="utf-8") as handle, \
            _file_lock(handle, fcntl.LOCK_SH):
        first_line = handle.readline()
    return int(first_line)
def write_kv_cache_bytes_to_file(rank, kv_cache_bytes):
    """Store *kv_cache_bytes* for *rank* in the KV-cache-bytes cache dir.

    Creates the directory if needed, then replaces the contents of
    ``f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}"`` with the value's string
    form while holding an exclusive ``flock``.
    """
    kv_cache_bytes_cache_abs_path = _get_torchair_current_work_dir(
        KV_CACHE_BYTES_CACHE_PATH_NAME)
    os.makedirs(kv_cache_bytes_cache_abs_path, exist_ok=True)
    kv_cache_bytes_file = os.path.join(
        kv_cache_bytes_cache_abs_path,
        f"{rank}_{KV_CACHE_BYTES_CACHE_FILE_NAME}")
    # Open WITHOUT truncating: mode "w" truncates inside open(), before the
    # lock is taken, so a reader holding LOCK_SH could see an empty file.
    with open(kv_cache_bytes_file, "a", encoding="utf-8") as f:
        with _file_lock(f, fcntl.LOCK_EX):
            # Truncate only once we own the lock. In append mode every write
            # lands at end-of-file, which is offset 0 after truncation.
            f.truncate(0)
            f.write(f"{kv_cache_bytes}")
            # Flush while still locked; otherwise the buffered data would
            # only reach the file at close(), after the lock is released.
            f.flush()
def delete_torchair_cache_file():
    """Recursively remove the torchair cache directory if it exists.

    A missing directory is a no-op. If the tree disappears between the
    existence check and the removal (e.g. another process deleted it), the
    resulting ``FileNotFoundError`` is ignored rather than propagated.
    """
    torch_air_abs_path = _get_torchair_current_work_dir()
    if not os.path.exists(torch_air_abs_path):
        return
    try:
        shutil.rmtree(torch_air_abs_path)
    except FileNotFoundError:
        # Already gone (check-then-remove race) -- nothing left to delete.
        pass
def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
    """Return torchair's stream-switch context for *tag* and *priority*.

    When *enabled* is False, a no-op ``nullcontext()`` is returned instead,
    so callers can use a single ``with`` statement either way.
    """
    if not enabled:
        return nullcontext()
    return _npu_stream_switch(tag, priority)
def npu_wait_tensor(self: torch.Tensor,
                    dependency: torch.Tensor,
                    *,
                    enabled: bool = True):
    """Forward to torchair's ``npu_wait_tensor(self, dependency)``.

    Presumably this makes *self* wait on *dependency* in graph mode (see the
    torchair docs). When *enabled* is False, *self* is returned unchanged
    and *dependency* is ignored.
    """
    if not enabled:
        return self
    return _npu_wait_tensor(self, dependency)
def register_torchair_model():
    """Register the TorchAir model implementations with vLLM.

    Each ``(architecture, "module:Class")`` pair below is passed to
    ``ModelRegistry.register_model`` in order.
    """
    from vllm import ModelRegistry

    torchair_models = (
        ("DeepSeekMTPModel",
         "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"),
        ("DeepseekV2ForCausalLM",
         "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"),
        ("DeepseekV3ForCausalLM",
         "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"),
        ("Qwen2ForCausalLM",
         "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"),
        # The Qwen3-MoE implementation must be bound to the MoE architecture
        # name. Registering it under the dense "Qwen3ForCausalLM" name would
        # map dense Qwen3 models onto the MoE model class.
        ("Qwen3MoeForCausalLM",
         "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM"),
    )
    for arch_name, model_cls_path in torchair_models:
        ModelRegistry.register_model(arch_name, model_cls_path)