Files
xc-llm-ascend/vllm_ascend/torchair/torchair_worker.py
wangxiyuan 7265dc090d [2/4][Refactor] Refactor torchair utils (#1892)
There is a lot of torchair-specific logic in common code, which makes the
code hard to maintain. We will create a new torchair module to hold the
torchair-related logic. I plan to split this work into 4 PRs.

1. Refactor worker
2. Refactor utils (this PR)
- a simple change that moves all torchair-related util functions to the
torchair module
3. Refactor model_runner
4. Refactor attention

- vLLM version: v0.9.2
- vLLM main:
8188196a1c

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-07-21 19:43:30 +08:00

65 lines
2.8 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from vllm.logger import logger
import vllm_ascend.envs as envs_ascend
from vllm_ascend.torchair.utils import (check_kv_cache_bytes_cache_exist,
check_torchair_cache_exist,
delete_torchair_cache_file,
read_kv_cache_bytes_from_file)
from vllm_ascend.worker.worker_v1 import NPUWorker
class NPUTorchairWorker(NPUWorker):
    """Torchair worker based on NPUWorker.

    Only torchair-specific code should be added to this class; all other
    behaviour is inherited from ``NPUWorker``.
    """

    def determine_available_memory(self) -> int:
        """Return the KV-cache byte budget, preferring a cached torchair value.

        The budget measured by ``NPUWorker.determine_available_memory`` is the
        starting point. If both a torchair cache and a cached
        ``kv_cache_bytes`` file exist, the value cached for this rank (as given
        by ``torch.distributed.get_rank()``) is reused when it is positive and
        does not exceed the freshly measured budget. Otherwise the torchair
        cache files are deleted and a new budget is used: the measured value
        minus ``VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE`` MiB.

        Whichever value is chosen is also stored on
        ``self.model_runner.new_kv_cache_bytes``.

        Returns:
            The number of bytes to use for the KV cache.
        """
        available_kv_cache_memory = super().determine_available_memory()

        if check_torchair_cache_exist() and check_kv_cache_bytes_cache_exist():
            old_kv_cache_bytes = read_kv_cache_bytes_from_file(
                torch.distributed.get_rank())
            if 0 < old_kv_cache_bytes <= available_kv_cache_memory:
                logger.info("Use cached torchair kv_cache_bytes: %s",
                            old_kv_cache_bytes)
                self.model_runner.new_kv_cache_bytes = old_kv_cache_bytes
                return old_kv_cache_bytes

            # The cached value cannot be reused; log the actual reason so a
            # non-positive cached value is not reported as "too big".
            if old_kv_cache_bytes <= 0:
                logger.info(
                    "Cached torchair kv_cache_bytes (%s) is not positive, "
                    "invalidate old torchair_cache", old_kv_cache_bytes)
            else:
                logger.info(
                    "Cached torchair kv_cache_bytes (%s) is too big for the "
                    "available kv_cache memory (%s), invalidate old "
                    "torchair_cache", old_kv_cache_bytes,
                    available_kv_cache_memory)
            delete_torchair_cache_file()

        # Leave some headroom below the measured budget (tolerance is given in
        # MiB). Presumably this lets a later run, whose measurement may vary
        # slightly, still pass the "cached <= available" check above -- TODO
        # confirm intent.
        bytes_floating_tolerance = (
            1024 * 1024 *
            envs_ascend.VLLM_ASCEND_KV_CACHE_MEGABYTES_FLOATING_TOLERANCE)
        available_kv_cache_memory -= bytes_floating_tolerance
        logger.info("Use new kv_cache_bytes: %s", available_kv_cache_memory)
        self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
        return available_kv_cache_memory

    def _get_max_num_tokens_and_with_prefill(self):
        """Return ``(max_num_tokens, with_prefill)`` adjusted for torchair.

        The parent's result is passed through unchanged when ``with_prefill``
        is true. Otherwise ``max_num_tokens`` is replaced by
        ``self.model_runner.select_torchair_padded_batch_size(max_num_tokens)``.
        That call presumably rounds it to a torchair graph batch size (confirm
        against the model runner).
        """
        max_num_tokens, with_prefill = (
            super()._get_max_num_tokens_and_with_prefill())
        if not with_prefill:
            max_num_tokens = self.model_runner.select_torchair_padded_batch_size(
                max_num_tokens)
        return max_num_tokens, with_prefill