[2/2] Support Qserve (#6521)
This commit is contained in:
@@ -349,6 +349,7 @@ class ModelConfig:
|
||||
"w8a8_int8",
|
||||
"w8a8_fp8",
|
||||
"moe_wna16",
|
||||
"qoq",
|
||||
]
|
||||
compatible_quantization_methods = {
|
||||
"modelopt_fp4": ["modelopt"],
|
||||
@@ -458,6 +459,8 @@ def _get_and_verify_dtype(
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
config_dtype = getattr(config, "torch_dtype", None)
|
||||
if isinstance(config_dtype, str):
|
||||
config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
|
||||
|
||||
Reference in New Issue
Block a user