Support Llama4 fp8 inference (#5194)

Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
Author: HandH1998
Date: 2025-04-09 20:14:34 +08:00
Committed by: GitHub
Parent: 86a876d883
Commit: 4065248214
14 changed files with 537 additions and 106 deletions


@@ -108,11 +108,15 @@ logger = logging.getLogger(__name__)
 def _get_quantization_config(
-    model_config: ModelConfig, load_config: LoadConfig
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    packed_modules_mapping: Dict[str, List[str]],
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
-        quant_config = get_quant_config(model_config, load_config)
+        quant_config = get_quant_config(
+            model_config, load_config, packed_modules_mapping
+        )
         major, minor = get_device_capability()
         if major is not None and minor is not None:
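
For context, `packed_modules_mapping` is typically a class attribute on a model implementation that maps a fused runtime module to the per-checkpoint projections it packs. The dict below is a common Llama-style example for illustration only; the exact keys and values used by the Llama4 model class are not shown in this hunk.

```python
# Hypothetical example of a packed_modules_mapping class attribute on a
# model class; concrete entries depend on the specific model definition.
packed_modules_mapping = {
    # fused runtime module  ->  checkpoint projections it packs
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
```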
@@ -142,7 +146,10 @@ def _initialize_model(
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
-    quant_config = _get_quantization_config(model_config, load_config)
+    packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    quant_config = _get_quantization_config(
+        model_config, load_config, packed_modules_mapping
+    )
     return model_class(
         config=model_config.hf_config,
         quant_config=quant_config,
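
A minimal sketch of why the mapping is threaded through to quantization-config resolution: an fp8 checkpoint may list ignored (unquantized) layers by their original checkpoint names, so the config needs to expand a fused runtime module back into the sub-names it packs before matching. The helper `expand_fused_names` below is purely illustrative and is not part of the sglang API.

```python
from typing import Dict, List

def expand_fused_names(
    module_name: str, packed_modules_mapping: Dict[str, List[str]]
) -> List[str]:
    """Map a fused runtime module (e.g. 'qkv_proj') back to the checkpoint
    names it packs so it can be matched against ignore/target lists.
    Illustrative helper, not the actual sglang implementation."""
    prefix, _, leaf = module_name.rpartition(".")
    if leaf in packed_modules_mapping:
        return [
            f"{prefix}.{sub}" if prefix else sub
            for sub in packed_modules_mapping[leaf]
        ]
    return [module_name]

# Example: a fused layer stays unquantized only if every sub-projection
# it packs appears in the checkpoint's ignore list.
mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
ignored = {
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
}
names = expand_fused_names("model.layers.0.self_attn.qkv_proj", mapping)
skip_quant = all(n in ignored for n in names)  # True
```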