Support Llama4 fp8 inference (#5194)

Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
HandH1998
2025-04-09 20:14:34 +08:00
committed by GitHub
parent 86a876d883
commit 4065248214
14 changed files with 537 additions and 106 deletions

View File

@@ -414,7 +414,7 @@ class Llama4Model(nn.Module):
lambda idx, prefix: Llama4DecoderLayer(
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
),
prefix="model.layers",
prefix=add_prefix("layers", prefix),
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)