Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm_mlu/vllm_mlu/utils.py
2026-02-04 17:22:39 +08:00

49 lines
1.7 KiB
Python

import torch
from typing import Optional, Union
from vllm import utils
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm_mlu.mlu_hijack_utils import MluHijackObject
# Add an "int8" entry to the string-to-dtype table imported from vllm.utils so
# that the name "int8" resolves to torch.int8. The table object is mutated in
# place rather than copied.
STR_DTYPE_TO_TORCH_DTYPE["int8"] = torch.int8
def vllm__utils__get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: resolve the KV-cache dtype to a torch.dtype, using
            STR_DTYPE_TO_TORCH_DTYPE for named dtypes.

    cache_dtype may be a torch.dtype (returned unchanged) or one of the
    strings below:
      - "auto": follow model_dtype, which must itself be a string key of
        STR_DTYPE_TO_TORCH_DTYPE or a torch.dtype.
      - "half" / "bfloat16" / "float": looked up in
        STR_DTYPE_TO_TORCH_DTYPE.
      - "fp8": mapped to torch.uint8.
      - "int8": mapped to torch.int8.

    Raises ValueError when cache_dtype is neither a recognised string nor a
    torch.dtype, or when cache_dtype is "auto" and model_dtype is neither a
    string nor a torch.dtype. A model_dtype string that is absent from
    STR_DTYPE_TO_TORCH_DTYPE is not validated here; the lookup failure
    propagates unchanged.
    '''
    # Already a concrete dtype: nothing to resolve.
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype

    # Anything that is not a string at this point (None included) is rejected.
    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    if cache_dtype == "auto":
        # "auto" means the cache uses whatever dtype the model uses.
        if isinstance(model_dtype, str):
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        raise ValueError(f"Invalid model dtype: {model_dtype}")

    if cache_dtype in ("half", "bfloat16", "float"):
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]

    if cache_dtype == "fp8":
        return torch.uint8

    if cache_dtype == 'int8':
        return torch.int8

    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
'''
==================
End of MLU Hijack
==================
'''
# Hand the MLU replacement to MluHijackObject.apply_hijack, together with the
# vllm.utils module and the original function it stands in for.
# NOTE(review): presumably this rebinds utils.get_kv_cache_torch_dtype to the
# MLU version; confirm against MluHijackObject.apply_hijack.
MluHijackObject.apply_hijack(utils,
utils.get_kv_cache_torch_dtype,
vllm__utils__get_kv_cache_torch_dtype)