Support Llama4 fp8 inference (#5194)
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn> Co-authored-by: sleepcoo <sleepcoo@gmail.com> Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -108,11 +108,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_quantization_config(
|
||||
model_config: ModelConfig, load_config: LoadConfig
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
packed_modules_mapping: Dict[str, List[str]],
|
||||
) -> Optional[QuantizationConfig]:
|
||||
"""Get the quantization config."""
|
||||
if model_config.quantization is not None:
|
||||
quant_config = get_quant_config(model_config, load_config)
|
||||
quant_config = get_quant_config(
|
||||
model_config, load_config, packed_modules_mapping
|
||||
)
|
||||
major, minor = get_device_capability()
|
||||
|
||||
if major is not None and minor is not None:
|
||||
@@ -142,7 +146,10 @@ def _initialize_model(
|
||||
) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
quant_config = _get_quantization_config(model_config, load_config)
|
||||
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
|
||||
quant_config = _get_quantization_config(
|
||||
model_config, load_config, packed_modules_mapping
|
||||
)
|
||||
return model_class(
|
||||
config=model_config.hf_config,
|
||||
quant_config=quant_config,
|
||||
|
||||
Reference in New Issue
Block a user