Add DeepSeek V3/R1 shared experts fusion (#4918)
@@ -399,7 +399,12 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]:
-        E = config.n_routed_experts
+        n_share_fusion_experts = args.n_share_experts_fusion
+        E = (
+            config.n_routed_experts + n_share_fusion_experts
+            if config.architectures[0] in ["DeepseekV3ForCausalLM"]
+            else config.n_routed_experts
+        )
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
@@ -559,6 +564,12 @@ if __name__ == "__main__":
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--tune", action="store_true")
+    parser.add_argument(
+        "--n-share-experts-fusion",
+        type=int,
+        default=0,
+        help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1",
+    )
     args = parser.parse_args()

     main(args)
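These first two hunks touch what appears to be the fused-MoE tuning benchmark: when shared-experts fusion is enabled for DeepSeek V3, the expert count E that the tuner sweeps is enlarged by the number of fused shared-expert replicas, so the generated kernel configs cover the extra expert slots. As a rough illustration (the concrete values below are hypothetical, not taken from this PR), the sizing reduces to:

    # Hypothetical values for illustration; not part of the commit.
    n_routed_experts = 256        # DeepSeek V3/R1 routed experts
    n_share_experts_fusion = 8    # e.g. one shared-expert replica per TP rank at TP=8
    moe_intermediate_size = 2048
    tp_size = 8

    is_v3 = True  # config.architectures[0] == "DeepseekV3ForCausalLM"
    E = n_routed_experts + n_share_experts_fusion if is_v3 else n_routed_experts
    shard_intermediate_size = 2 * moe_intermediate_size // tp_size
    print(E, shard_intermediate_size)  # 264 512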
@@ -993,13 +993,16 @@ async def benchmark(
             return await request_func(request_func_input=request_func_input, pbar=pbar)

     # Warmup
-    print("Starting initial single prompt test run...")
+    print(f"Starting warmup with {args.warmup_requests} sequences...")
+
+    # Use the first request for all warmup iterations
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
     if lora_names != None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
         lora_name = None

+    # Create the test input once
     test_input = RequestFuncInput(
         model=model_id,
         prompt=test_prompt,
@@ -1009,14 +1012,26 @@ async def benchmark(
         lora_name=lora_name,
         extra_request_body=extra_request_body,
     )
-    test_output = await request_func(request_func_input=test_input)
-    if not test_output.success:
+
+    # Run warmup requests
+    warmup_tasks = []
+    for _ in range(args.warmup_requests):
+        warmup_tasks.append(
+            asyncio.create_task(request_func(request_func_input=test_input))
+        )
+
+    warmup_outputs = await asyncio.gather(*warmup_tasks)
+
+    # Check if at least one warmup request succeeded
+    if not any(output.success for output in warmup_outputs):
         raise ValueError(
-            "Initial test run failed - Please make sure benchmark arguments "
-            f"are correctly specified. Error: {test_output.error}"
+            "Warmup failed - Please make sure benchmark arguments "
+            f"are correctly specified. Error: {warmup_outputs[0].error}"
         )
     else:
-        print("Initial test run completed. Starting main benchmark run...")
+        print(
+            f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
+        )

     # Flush cache
     if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
@@ -1253,6 +1268,10 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "max_concurrency"):
         args.max_concurrency = None

+    # Set default value for warmup_requests if not present
+    if not hasattr(args, "warmup_requests"):
+        args.warmup_requests = 1
+
     print(f"benchmark_args={args}")

     # Set global environments
@@ -1560,6 +1579,12 @@ if __name__ == "__main__":
         action="store_true",
         help="Flush the cache before running the benchmark",
     )
+    parser.add_argument(
+        "--warmup-requests",
+        type=int,
+        default=1,
+        help="Number of warmup requests to run before the benchmark",
+    )

     group = parser.add_argument_group("generated-shared-prefix dataset arguments")
     group.add_argument(
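The serving-benchmark hunks above replace the single warmup request with a configurable batch of concurrent warmup requests (--warmup-requests, default 1), all built from the first input request, and the run only aborts if none of them succeeds. A minimal, self-contained sketch of the same asyncio pattern, with a stand-in request function rather than the benchmark's real one:

    import asyncio
    import random

    async def fake_request(prompt: str) -> bool:
        # Stand-in for the real request function; returns a success flag.
        await asyncio.sleep(0.01)
        return random.random() > 0.1

    async def warmup(prompt: str, warmup_requests: int = 4) -> None:
        tasks = [asyncio.create_task(fake_request(prompt)) for _ in range(warmup_requests)]
        results = await asyncio.gather(*tasks)
        if not any(results):
            raise ValueError("Warmup failed - please check the benchmark arguments.")
        print(f"Warmup completed with {warmup_requests} sequences.")

    asyncio.run(warmup("hello"))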
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},
  "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
  "2": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
  "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3}
}
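The two new JSON files are Triton fused-MoE tuning tables, keyed by token-batch size, giving the tile sizes (BLOCK_SIZE_M/N/K, GROUP_SIZE_M) and pipeline depth (num_warps, num_stages) to use at each batch size; they cover the enlarged expert count introduced by the fusion. The target filenames (which normally encode the expert count, shard size, dtype, and GPU) are not visible in this extract. At run time such a table is typically resolved to the nearest available batch-size key; a hedged sketch of that lookup (the nearest-key rule is an assumption here, not quoted from sglang):

    import json

    def load_moe_config(path: str) -> dict:
        # Keys are stored as strings in the JSON file; convert to int for lookup.
        with open(path) as f:
            return {int(k): v for k, v in json.load(f).items()}

    def pick_config(configs: dict, num_tokens: int) -> dict:
        # Assumed rule: use the entry whose key is closest to the current batch size.
        best_key = min(configs, key=lambda k: abs(k - num_tokens))
        return configs[best_key]

    # Example with an inline table instead of a file:
    configs = {1: {"BLOCK_SIZE_M": 16}, 512: {"BLOCK_SIZE_M": 64}}
    print(pick_config(configs, 300))  # {'BLOCK_SIZE_M': 64}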
@@ -13,11 +13,6 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.moe.topk import select_experts
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.layers.quantization.int8_kernel import (
-    per_token_group_quant_int8,
-    per_token_quant_int8,
-)
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -42,9 +37,6 @@ if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul

     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
 else:
     from vllm import _custom_ops as vllm_ops

@@ -764,6 +756,16 @@ def invoke_fused_moe_kernel(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
+    from sglang.srt.layers.quantization.int8_kernel import (
+        per_token_group_quant_int8,
+        per_token_quant_int8,
+    )
+
+    if _is_cuda:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
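In the fused-MoE kernel module, the fp8/int8 quantization imports move from module scope into invoke_fused_moe_kernel, so the quantization modules are only loaded the first time the kernel is actually invoked. Deferring an import into a function body is the standard way to break a cycle when two modules need each other; a generic, self-contained sketch of the idea (two throwaway modules written to a temp dir, not the actual sglang modules):

    import os
    import sys
    import tempfile

    with tempfile.TemporaryDirectory() as d:
        # mod_a imports mod_b at module scope; mod_b only imports mod_a inside a function.
        with open(os.path.join(d, "mod_a.py"), "w") as f:
            f.write("import mod_b\n"
                    "VALUE = 'a.VALUE'\n"
                    "def call_b():\n"
                    "    return mod_b.use_a()\n")
        with open(os.path.join(d, "mod_b.py"), "w") as f:
            f.write("def use_a():\n"
                    "    from mod_a import VALUE  # deferred: resolved at call time\n"
                    "    return VALUE\n")
        sys.path.insert(0, d)
        import mod_a
        print(mod_a.call_b())  # prints a.VALUE with no circular-import error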
@@ -12,12 +12,14 @@
 # limitations under the License.
 # ==============================================================================

+import os
 from typing import Callable, Optional

 import torch
 import torch.nn.functional as F

 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

 _is_cuda = is_cuda()
@@ -102,11 +104,13 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
@@ -122,9 +126,25 @@ def grouped_topk(
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.
+
     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

@@ -137,11 +157,13 @@ def biased_grouped_topk_impl(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = gating_output.sigmoid()
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
     group_scores = (
         scores_for_choice.view(num_token, num_expert_group, -1)
@@ -164,8 +186,25 @@ def biased_grouped_topk_impl(
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)

+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.
+
     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

@@ -179,6 +218,7 @@ def biased_grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     compiled: bool = True,
+    n_share_experts_fusion: int = 0,
 ):
     biased_grouped_topk_fn = (
         torch.compile(
@@ -195,6 +235,7 @@ def biased_grouped_topk(
         renormalize,
         num_expert_group,
         topk_group,
+        n_share_experts_fusion=n_share_experts_fusion,
     )


@@ -210,7 +251,10 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeekSeekv2 uses grouped_top_k
+    n_share_experts_fusion = 0
+    if global_server_args_dict["n_share_experts_fusion"] is not None:
+        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    # DeekSeek V2/V3/R1 serices models uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -222,6 +266,7 @@ def select_experts(
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
+                n_share_experts_fusion=n_share_experts_fusion,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -232,6 +277,7 @@ def select_experts(
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
+                n_share_experts_fusion=n_share_experts_fusion,
             )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
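The routing hunks above are the core of the fusion: grouped_topk and biased_grouped_topk_impl gain an n_share_experts_fusion argument, and when it is non-zero the last top-k slot is overwritten with a randomly chosen shared-expert replica (ids num_experts through num_experts + n_share_experts_fusion - 1, so the replicas share the load), its weight is set to the sum of the routed weights divided by 2.5 (the routed_scaling_factor, per the in-code comment), and renormalization then runs over the routed slots only. A standalone sketch of just this post-processing step, with random stand-in inputs:

    import torch

    def fuse_shared_expert_slot(topk_weights, topk_ids, num_experts,
                                n_share_experts_fusion, renormalize=True,
                                routed_scaling_factor=2.5):
        # Point the last slot at one of the shared-expert replicas.
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + n_share_experts_fusion,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        # Weight it with sum(routed weights) / routed_scaling_factor, as in the diff.
        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
        if renormalize:
            # Normalize by the routed slots only, mirroring the modified renormalize path.
            topk_weights = topk_weights / topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids

    # 4 tokens, top-k of 9 (8 routed slots + 1 fused slot), 256 routed experts, 8 replicas.
    w = torch.rand(4, 9)
    ids = torch.randint(0, 256, (4, 9), dtype=torch.int32)
    w, ids = fuse_shared_expert_slot(w, ids, num_experts=256, n_share_experts_fusion=8)
    print(ids[:, -1])  # values in [256, 264)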
@@ -51,7 +51,6 @@ except ImportError:


 from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -203,6 +202,8 @@ def get_linear_quant_method(


 def gptq_get_quant_method(self, layer, prefix):
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
     if isinstance(layer, FusedMoE):
         return GPTQMarlinMoEMethod(self)
@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
                 return UnquantizedLinearMethod()
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
         if isinstance(layer, FusedMoE):
             return CompressedTensorsMoEMethod.get_moe_method(self)
         return None
@@ -4,18 +4,19 @@
 import enum
 import logging
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional

 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy

-from sglang.srt.layers.moe.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
-from sglang.srt.layers.moe.topk import select_experts
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
+    )

 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -55,7 +56,13 @@ __all__ = [
 ]


-class CompressedTensorsMoEMethod(FusedMoEMethodBase):
+class CompressedTensorsMoEMethod:
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if cls is CompressedTensorsMoEMethod:
+            return super().__new__(cls)
+        return super().__new__(cls)

     @staticmethod
     def get_moe_method(
@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import (
+            FusedMoEMethodBase,
+            FusedMoeWeightScaleSupported,
+        )
+
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         params_dtype = torch.float8_e4m3fn

@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.fused_moe_triton import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts

         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
             w1_scale=layer.w13_weight_scale,
@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import (
+            FusedMoEMethodBase,
+            FusedMoeWeightScaleSupported,
+        )
+
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.topk import select_experts
+
         assert activation == "silu", "Only SiLU activation is supported."
         if not VLLM_AVAILABLE:
             raise ImportError(
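The quantization-side hunks all chase the same goal as the kernel change: FusedMoE, FusedMoEMethodBase, and FusedMoeWeightScaleSupported are no longer imported at module scope (which would reintroduce a cycle with the fused-MoE package); names needed only for annotations go behind typing.TYPE_CHECKING, and names needed at run time are imported locally inside the methods that use them. A minimal, runnable sketch of the idiom, using decimal.Decimal purely as a stand-in for FusedMoE:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by static type checkers, never executed at run time,
        # so it cannot create an import cycle.
        from decimal import Decimal


    class ExampleQuantConfig:
        def get_quant_method(self, layer: "Decimal"):
            # Deferred runtime import, like the local FusedMoE imports in the diff.
            from decimal import Decimal

            if isinstance(layer, Decimal):
                return "moe-method"
            return "linear-method"


    if __name__ == "__main__":
        from decimal import Decimal

        print(ExampleQuantConfig().get_quant_method(Decimal("1")))  # moe-method
        print(ExampleQuantConfig().get_quant_method(3.0))           # linear-method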
@@ -81,6 +81,8 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
+    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
+    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
 }

 logger = logging.getLogger(__name__)
@@ -157,6 +157,8 @@ class ModelRunner:
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "n_share_experts_fusion": server_args.n_share_experts_fusion,
+                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
             }
         )
@@ -16,12 +16,14 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""

+import logging
 import os
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
@@ -87,6 +89,8 @@ if _is_hip:

 expert_distribution_recorder = ExpertDistributionRecorder()

+logger = logging.getLogger(__name__)
+

 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -168,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
+        self.n_share_experts_fusion = (
+            global_server_args_dict["n_share_experts_fusion"]
+            if global_server_args_dict["n_share_experts_fusion"] is not None
+            else 0
+        )
+
         self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -188,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
+
         self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
+            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
@@ -207,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
             ),
         )

-        if config.n_shared_experts is not None:
+        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
             if not global_server_args_dict["enable_deepep_moe"]:
@@ -267,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_deepep(hidden_states, forward_mode)

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None:
+        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
@@ -1315,7 +1328,28 @@ class DeepseekV2ForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        if (
+            global_server_args_dict.get("disable_shared_experts_fusion", False)
+            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
+            or self.config.n_routed_experts != 256
+            or self.config.routed_scaling_factor != 2.5
+        ):
+            self.n_share_experts_fusion = None
+            global_server_args_dict["n_share_experts_fusion"] = None
+            logger.info(
+                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+            )
+        elif self.n_share_experts_fusion is None:
+            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+            self.n_share_experts_fusion = self.tp_size
+            logger.info(
+                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
+            )
+
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -1352,6 +1386,43 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            suffix_list = [
+                "down_proj.weight",
+                "down_proj.weight_scale_inv",
+                "gate_proj.weight",
+                "gate_proj.weight_scale_inv",
+                "up_proj.weight",
+                "up_proj.weight_scale_inv",
+            ]
+            names_to_remove = []
+            for moe_layer in tqdm(
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                ),
+                desc=f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE",
+            ):
+                for num_repeat in range(self.n_share_experts_fusion):
+                    for suffix in suffix_list:
+                        shared_expert_weight_name = (
+                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                        )
+                        weights_list.append(
+                            (
+                                f"model.layers.{moe_layer}."
+                                f"mlp.experts."
+                                f"{self.config.n_routed_experts + num_repeat}"
+                                f".{suffix}",
+                                weights_dict[shared_expert_weight_name].clone(),
+                            )
+                        )
+                        names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]
+
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -1364,7 +1435,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.n_share_experts_fusion
+                if self.n_share_experts_fusion is not None
+                else 0
+            ),
         )

         params_dict = dict(self.named_parameters())
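On the model side, when fusion is active the standalone shared-expert MLP is skipped, the FusedMoE layer is built with n_share_experts_fusion extra expert slots and one extra top-k slot, and at load time the shared expert's checkpoint tensors are cloned into those slots (expert ids n_routed_experts through n_routed_experts + n_share_experts_fusion - 1). A small sketch of just the renaming performed by the cloning loop, with dummy strings standing in for the tensors:

    # Sketch of the weight-renaming step only; "weights" here are dummy placeholders.
    n_routed_experts = 256
    n_share_experts_fusion = 2
    moe_layer = 3  # a single MoE layer, for brevity
    suffixes = ["down_proj.weight", "gate_proj.weight", "up_proj.weight"]

    weights = {
        f"model.layers.{moe_layer}.mlp.shared_experts.{s}": f"<tensor:{s}>" for s in suffixes
    }

    cloned = {}
    for num_repeat in range(n_share_experts_fusion):
        for s in suffixes:
            src = f"model.layers.{moe_layer}.mlp.shared_experts.{s}"
            dst = (
                f"model.layers.{moe_layer}.mlp.experts."
                f"{n_routed_experts + num_repeat}.{s}"
            )
            cloned[dst] = weights[src]  # the real code calls .clone() on the tensor

    for name in sorted(cloned):
        print(name)
    # model.layers.3.mlp.experts.256.down_proj.weight
    # model.layers.3.mlp.experts.257.down_proj.weight
    # ... and so on for gate_proj / up_proj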
@@ -183,6 +183,8 @@ class ServerArgs:
     enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
+    n_share_experts_fusion: Optional[int] = None
+    disable_shared_experts_fusion: bool = False

     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -224,6 +226,9 @@ class ServerArgs:
             # GPU memory is not known yet or no GPU is available.
             gpu_mem = None

+        if is_hip():
+            self.disable_shared_experts_fusion = True
+
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -1102,6 +1107,19 @@ class ServerArgs:
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )

+        parser.add_argument(
+            "--n-share-experts-fusion",
+            type=int,
+            default=None,
+            help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
+            "we use tp_size by default.",
+        )
+        parser.add_argument(
+            "--disable-shared-experts-fusion",
+            action="store_true",
+            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+        )
+
         # Server warmups
         parser.add_argument(
             "--warmups",
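Taken together, the server arguments and the model-side gating boil down to this: fusion is only attempted for a DeepSeek V3/R1-shaped config (256 routed experts, routed_scaling_factor 2.5), it is force-disabled on HIP and via --disable-shared-experts-fusion, and when --n-share-experts-fusion is left unset it defaults to the TP size (one shared-expert replica per rank). A compressed, hedged restatement of that decision logic (this helper does not exist in sglang; it only mirrors the checks added in DeepseekV2ForCausalLM.__init__):

    from typing import Optional

    def resolve_n_share_experts_fusion(
        n_share_experts_fusion: Optional[int],
        disable_shared_experts_fusion: bool,
        tp_size: int,
        architecture: str,
        n_routed_experts: int,
        routed_scaling_factor: float,
    ) -> Optional[int]:
        # Hypothetical helper summarizing the gating added by this PR.
        if (
            disable_shared_experts_fusion
            or architecture != "DeepseekV3ForCausalLM"
            or n_routed_experts != 256
            or routed_scaling_factor != 2.5
        ):
            return None
        return n_share_experts_fusion if n_share_experts_fusion is not None else tp_size

    print(resolve_n_share_experts_fusion(None, False, 8, "DeepseekV3ForCausalLM", 256, 2.5))  # 8
    print(resolve_n_share_experts_fusion(None, False, 8, "DeepseekV2ForCausalLM", 64, 16.0))  # None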
@@ -144,7 +144,7 @@ def moe_align_block_size_triton(
             [32, 64, 128, 256],  # block_size
             [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],  # num_tokens
             [1, 2, 4, 8, 16, 32, 64],  # topk
-            [64, 160, 256],  # num_experts
+            [64, 160, 256, 257, 260, 264],  # num_experts
         )
     ),
 )