Support double sparsity (#1459)

This commit is contained in:
Shuo Yang
2024-10-14 02:00:41 -07:00
committed by GitHub
parent 0c1e87964b
commit 061e546313
8 changed files with 1269 additions and 1 deletions

View File

@@ -86,6 +86,14 @@ class ServerArgs:
# Model override args in JSON
json_model_override_args: str = "{}"
# Double Sparsity
enable_double_sparsity: bool = False
ds_channel_config_path: str = None
ds_heavy_channel_num: int = 32
ds_heavy_token_num: int = 256
ds_heavy_channel_type: str = "qk"
ds_sparse_decode_threshold: int = 4096
# LoRA
lora_paths: Optional[List[str]] = None
max_loras_per_batch: int = 8
@@ -443,6 +451,43 @@ class ServerArgs:
default=ServerArgs.json_model_override_args,
)
# Double Sparsity
parser.add_argument(
"--enable-double-sparsity",
action="store_true",
help="Enable double sparsity attention",
)
parser.add_argument(
"--ds-channel-config-path",
type=str,
default=ServerArgs.ds_channel_config_path,
help="The path of the double sparsity channel config",
)
parser.add_argument(
"--ds-heavy-channel-num",
type=int,
default=ServerArgs.ds_heavy_channel_num,
help="The number of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-token-num",
type=int,
default=ServerArgs.ds_heavy_token_num,
help="The number of heavy tokens in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-channel-type",
type=str,
default=ServerArgs.ds_heavy_channel_type,
help="The type of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-sparse-decode-threshold",
type=int,
default=ServerArgs.ds_sparse_decode_threshold,
help="The type of heavy channels in double sparsity attention",
)
# LoRA
parser.add_argument(
"--lora-paths",