Add support for nvidia modelopt fp8 kv cache (#3223)
This commit is contained in:
@@ -644,9 +644,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
        return remapped_name

    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            remapped_name = name.replace(scale_name, f".attn{scale_name}")
            # Check and remap the name based on modelopt scale names
            if any(
                modelopt_scale_name in name
                for modelopt_scale_name in modelopt_scale_names
            ):
                remapped_name = name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}",
                )
            else:
                remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
                print_warning_once(
                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
Reference in New Issue
Block a user