Support Llama4 fp8 inference (#5194)

Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
HandH1998
2025-04-09 20:14:34 +08:00
committed by GitHub
parent 86a876d883
commit 4065248214
14 changed files with 537 additions and 106 deletions

View File

@@ -129,7 +129,9 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to another place.
def get_quant_config(
model_config: ModelConfig, load_config: LoadConfig
model_config: ModelConfig,
load_config: LoadConfig,
packed_modules_mapping: Dict[str, List[str]],
) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
@@ -147,6 +149,7 @@ def get_quant_config(
# compressed-tensors uses a compression_config
hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
if hf_quant_config is not None:
hf_quant_config["packed_modules_mapping"] = packed_modules_mapping
return quant_cls.from_config(hf_quant_config)
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
if model_config.quantization == "bitsandbytes":