Support Llama4 fp8 inference (#5194)
Co-authored-by: laixinn <xielx@shanghaitech.edu.cn>
Co-authored-by: sleepcoo <sleepcoo@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
@@ -414,7 +414,7 @@ class Llama4Model(nn.Module):
             lambda idx, prefix: Llama4DecoderLayer(
                 config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
             ),
-            prefix="model.layers",
+            prefix=add_prefix("layers", prefix),
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
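For context, `add_prefix` is a small weight-naming helper (presumably imported from SGLang's utils module). A minimal sketch of the expected behavior, which may differ from the actual implementation:

    def add_prefix(name: str, prefix: str) -> str:
        # Join a child module name onto its parent's prefix with a dot;
        # an empty parent prefix yields the bare name. Sketch only: the
        # real SGLang helper may differ in detail.
        return name if not prefix else f"{prefix}.{name}"

Deriving the layer prefix from the parent module rather than hard-coding "model.layers" presumably lets weight names resolve correctly when Llama4Model is nested under different wrapper modules.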