[CI] Improve Docs CI Efficiency (#3587)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
@@ -3,6 +3,7 @@ set -euxo pipefail
|
||||
|
||||
# Install the dependency in CI.
|
||||
|
||||
|
||||
# Use repo from environment variable, passed from GitHub Actions
|
||||
FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python}"
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from safetensors.torch import save_file
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def get_nexn_layer_id(config):
|
||||
def get_nextn_layer_id(config):
|
||||
if not hasattr(config, "num_hidden_layers"):
|
||||
raise ValueError("'num_hidden_layers' not found in model config.")
|
||||
return config.num_hidden_layers
|
||||
@@ -25,7 +25,7 @@ def update_and_save_config(config, output_dir):
|
||||
new_config = config.to_dict()
|
||||
new_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"num_hidden_layers": 1,
|
||||
"architectures": ["DeepseekV3ForCausalLMNextN"],
|
||||
}
|
||||
)
|
||||
@@ -42,8 +42,8 @@ def copy_non_safetensors_files(input_dir, output_dir):
|
||||
print(f"All non-safetensors files have been copied to {output_dir}")
|
||||
|
||||
|
||||
def export_nextn_layer_parameters(input_dir, output_dir, nexn_layer_id):
|
||||
prefix = f"model.layers.{nexn_layer_id}"
|
||||
def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
|
||||
prefix = f"model.layers.{nextn_layer_id}"
|
||||
output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
|
||||
params = {}
|
||||
for filename in os.listdir(input_dir):
|
||||
@@ -106,7 +106,7 @@ if __name__ == "__main__":
|
||||
|
||||
config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
|
||||
assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."
|
||||
nextn_layer_id = get_nexn_layer_id(config)
|
||||
nextn_layer_id = get_nextn_layer_id(config)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
copy_non_safetensors_files(args.input_dir, args.output_dir)
|
||||
update_and_save_config(config, args.output_dir)
|
||||
|
||||
Reference in New Issue
Block a user