Support Llama4 fp8 inference (#5194)
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn> Co-authored-by: sleepcoo <sleepcoo@gmail.com> Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -129,7 +129,9 @@ def convert_bin_to_safetensor_file(
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(
|
||||
model_config: ModelConfig, load_config: LoadConfig
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
packed_modules_mapping: Dict[str, List[str]],
|
||||
) -> QuantizationConfig:
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
|
||||
@@ -147,6 +149,7 @@ def get_quant_config(
|
||||
# compressed-tensors uses a compressions_config
|
||||
hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
|
||||
if hf_quant_config is not None:
|
||||
hf_quant_config["packed_modules_mapping"] = packed_modules_mapping
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
|
||||
Reference in New Issue
Block a user