Support double sparsity (#1459)
This commit is contained in:
@@ -86,6 +86,14 @@ class ServerArgs:
|
||||
# Model override args in JSON
|
||||
json_model_override_args: str = "{}"
|
||||
|
||||
# Double Sparsity
|
||||
enable_double_sparsity: bool = False
|
||||
ds_channel_config_path: str = None
|
||||
ds_heavy_channel_num: int = 32
|
||||
ds_heavy_token_num: int = 256
|
||||
ds_heavy_channel_type: str = "qk"
|
||||
ds_sparse_decode_threshold: int = 4096
|
||||
|
||||
# LoRA
|
||||
lora_paths: Optional[List[str]] = None
|
||||
max_loras_per_batch: int = 8
|
||||
@@ -443,6 +451,43 @@ class ServerArgs:
|
||||
default=ServerArgs.json_model_override_args,
|
||||
)
|
||||
|
||||
# Double Sparsity
|
||||
parser.add_argument(
|
||||
"--enable-double-sparsity",
|
||||
action="store_true",
|
||||
help="Enable double sparsity attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds-channel-config-path",
|
||||
type=str,
|
||||
default=ServerArgs.ds_channel_config_path,
|
||||
help="The path of the double sparsity channel config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds-heavy-channel-num",
|
||||
type=int,
|
||||
default=ServerArgs.ds_heavy_channel_num,
|
||||
help="The number of heavy channels in double sparsity attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds-heavy-token-num",
|
||||
type=int,
|
||||
default=ServerArgs.ds_heavy_token_num,
|
||||
help="The number of heavy tokens in double sparsity attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds-heavy-channel-type",
|
||||
type=str,
|
||||
default=ServerArgs.ds_heavy_channel_type,
|
||||
help="The type of heavy channels in double sparsity attention",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds-sparse-decode-threshold",
|
||||
type=int,
|
||||
default=ServerArgs.ds_sparse_decode_threshold,
|
||||
help="The threshold for sparse decoding in double sparsity attention",
|
||||
)
|
||||
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-paths",
|
||||
|
||||
Reference in New Issue
Block a user