# Forked from EngineX-Cambricon/enginex-mlu370-vllm
import torch
|
|
from typing import Optional, Union
|
|
from vllm import utils
|
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
# Extend vLLM's string->dtype lookup table so the "int8" kv-cache dtype
# string resolves to torch.int8 (used by the hijacked resolver below on MLU).
STR_DTYPE_TO_TORCH_DTYPE["int8"] = torch.int8
|
|
|
|
|
|
def vllm__utils__get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use STR_DTYPE_TO_TORCH_DTYPE to get torch_dtype

    Resolve the kv-cache dtype to a concrete torch.dtype.
    - a torch.dtype is returned unchanged
    - "auto" defers to model_dtype (string or torch.dtype)
    - "half"/"bfloat16"/"float" go through STR_DTYPE_TO_TORCH_DTYPE
    - "fp8" -> torch.uint8, "int8" -> torch.int8
    Raises ValueError for anything else.
    '''
    # A concrete torch.dtype passes straight through.
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype

    if not isinstance(cache_dtype, str):
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    if cache_dtype == "auto":
        # "auto" means: use whatever dtype the model itself runs in.
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        if isinstance(model_dtype, str):
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        raise ValueError(f"Invalid model dtype: {model_dtype}")

    if cache_dtype in ("half", "bfloat16", "float"):
        return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
    if cache_dtype == "fp8":
        # fp8 is represented here as uint8 storage (no native fp8 dtype used).
        return torch.uint8
    if cache_dtype == "int8":
        return torch.int8

    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
|
'''
==================
End of MLU Hijack
==================
'''
|
|
|
|
|
|
# Register the hijack: replace vllm.utils.get_kv_cache_torch_dtype with the
# MLU-aware resolver defined in this module.
MluHijackObject.apply_hijack(
    utils,
    utils.get_kv_cache_torch_dtype,
    vllm__utils__get_kv_cache_torch_dtype,
)