Add DeepSeek V3/R1 shared experts fusion (#4918)
@@ -993,13 +993,16 @@ async def benchmark(
         return await request_func(request_func_input=request_func_input, pbar=pbar)
 
     # Warmup
-    print("Starting initial single prompt test run...")
+    print(f"Starting warmup with {args.warmup_requests} sequences...")
+
+    # Use the first request for all warmup iterations
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
     if lora_names != None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
         lora_name = None
 
+    # Create the test input once
     test_input = RequestFuncInput(
         model=model_id,
         prompt=test_prompt,
@@ -1009,14 +1012,26 @@ async def benchmark(
         lora_name=lora_name,
         extra_request_body=extra_request_body,
     )
-    test_output = await request_func(request_func_input=test_input)
-    if not test_output.success:
+
+    # Run warmup requests
+    warmup_tasks = []
+    for _ in range(args.warmup_requests):
+        warmup_tasks.append(
+            asyncio.create_task(request_func(request_func_input=test_input))
+        )
+
+    warmup_outputs = await asyncio.gather(*warmup_tasks)
+
+    # Check if at least one warmup request succeeded
+    if not any(output.success for output in warmup_outputs):
         raise ValueError(
-            "Initial test run failed - Please make sure benchmark arguments "
-            f"are correctly specified. Error: {test_output.error}"
+            "Warmup failed - Please make sure benchmark arguments "
+            f"are correctly specified. Error: {warmup_outputs[0].error}"
         )
     else:
-        print("Initial test run completed. Starting main benchmark run...")
+        print(
+            f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
+        )
 
     # Flush cache
     if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
@@ -1253,6 +1268,10 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "max_concurrency"):
         args.max_concurrency = None
 
+    # Set default value for warmup_requests if not present
+    if not hasattr(args, "warmup_requests"):
+        args.warmup_requests = 1
+
     print(f"benchmark_args={args}")
 
     # Set global environments
@@ -1560,6 +1579,12 @@ if __name__ == "__main__":
         action="store_true",
         help="Flush the cache before running the benchmark",
     )
+    parser.add_argument(
+        "--warmup-requests",
+        type=int,
+        default=1,
+        help="Number of warmup requests to run before the benchmark",
+    )
 
     group = parser.add_argument_group("generated-shared-prefix dataset arguments")
    group.add_argument(
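
The warmup hunks above replace the old single "initial test run" with `--warmup-requests` concurrent copies of the first request, aborting only if every one of them fails. A minimal standalone sketch of that pattern, with a hypothetical `send_request` coroutine standing in for `request_func`:

```python
import asyncio
import random


async def send_request(payload: str) -> bool:
    # Stand-in for request_func: pretend a warmup call can occasionally fail.
    await asyncio.sleep(0.01)
    return random.random() > 0.2


async def warmup(payload: str, warmup_requests: int = 4) -> None:
    # Fire all warmup copies of the same request concurrently...
    tasks = [
        asyncio.create_task(send_request(payload)) for _ in range(warmup_requests)
    ]
    results = await asyncio.gather(*tasks)
    # ...and abort only if every single one failed.
    if not any(results):
        raise ValueError("Warmup failed - check benchmark arguments")
    print(f"Warmup completed with {warmup_requests} sequences.")


asyncio.run(warmup("hello"))
```
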
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
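
The two JSON files above are new Triton tuning tables for the fused MoE kernel shapes introduced by the fusion (the expert count grows from 256 to 256 plus the replicas). Keys are batch sizes M; values are tile sizes and pipeline depths for the kernel launch. Below is a sketch of how such a table is typically consulted; the nearest-key rule mirrors the fused_moe_triton config loader but is an assumption here, not code from this commit:

```python
import json


def pick_config(configs: dict, m: int) -> dict:
    # Use the tuned entry whose batch-size key is numerically closest to M.
    return configs[min(configs, key=lambda k: abs(int(k) - m))]


table = json.loads(
    '{"1": {"BLOCK_SIZE_M": 16}, "64": {"BLOCK_SIZE_M": 16}, "512": {"BLOCK_SIZE_M": 64}}'
)
assert pick_config(table, 200)["BLOCK_SIZE_M"] == 16  # 200 is closest to 64
assert pick_config(table, 500)["BLOCK_SIZE_M"] == 64  # 500 is closest to 512
```
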
@@ -13,11 +13,6 @@ import triton
 import triton.language as tl
 
-from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.layers.quantization.int8_kernel import (
-    per_token_group_quant_int8,
-    per_token_quant_int8,
-)
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -42,9 +37,6 @@ if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 
     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
 else:
     from vllm import _custom_ops as vllm_ops
 
@@ -764,6 +756,16 @@ def invoke_fused_moe_kernel(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
+    from sglang.srt.layers.quantization.int8_kernel import (
+        per_token_group_quant_int8,
+        per_token_quant_int8,
+    )
+
+    if _is_cuda:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
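
Several hunks in this commit (here, and in the quantization files below) move imports from module scope into function bodies. Once `topk.py` pulls in `global_server_args_dict`, module-level imports of the quant kernels would close an import cycle; deferring them means they resolve on first call, after all modules have finished loading. A self-contained sketch of why the deferred form works (the module name "a" is fabricated for illustration):

```python
import sys
import types

# Simulate a half-initialized module "a", as seen mid-cycle.
a = types.ModuleType("a")
sys.modules["a"] = a


def kernel_entry(x: int) -> int:
    # Deferred import: looked up at call time, by which point "a" is complete.
    from a import helper

    return helper(x)


# "a" finishes loading (gains its attribute) before anything calls kernel_entry.
a.helper = lambda x: x * 2
assert kernel_entry(3) == 6
```
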
@@ -12,12 +12,14 @@
 # limitations under the License.
 # ==============================================================================
 
+import os
 from typing import Callable, Optional
 
 import torch
 import torch.nn.functional as F
 
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
 
 _is_cuda = is_cuda()
@@ -102,11 +104,13 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
@@ -122,9 +126,25 @@ def grouped_topk(
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.
 
     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum
 
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
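
The fusion logic added to `grouped_topk` (and mirrored in `biased_grouped_topk_impl` below) works like this: the shared expert is replicated `n_share_experts_fusion` times as extra experts with ids `num_experts .. num_experts + n - 1`; each token's last top-k slot is pointed at a random replica to spread load across TP ranks; that slot gets weight `sum(routed weights) / 2.5`; and renormalization divides by the routed weights only. The net effect is that after renormalization the shared slot carries exactly `1 / routed_scaling_factor`, so when the model later multiplies the MoE output by `routed_scaling_factor` (2.5 for V3/R1), the shared expert contributes at weight 1.0, just as an unfused shared expert would. A self-contained sketch of the arithmetic:

```python
import torch

num_experts, n_fused, topk, scaling = 8, 2, 3, 2.5
scores = torch.rand(4, num_experts)  # [tokens, routed experts]

topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1)

# Last slot: a random shared-expert replica, id >= num_experts.
topk_ids[:, -1] = torch.randint(num_experts, num_experts + n_fused, (4,))
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / scaling

# Renormalize over the routed slots only.
topk_weights = topk_weights / topk_weights[:, :-1].sum(dim=-1, keepdim=True)

# The shared slot now carries exactly 1/scaling, so the model-level multiply
# by routed_scaling_factor restores its weight to 1.0.
assert torch.allclose(topk_weights[:, -1] * scaling, torch.ones(4))
```
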
@@ -137,11 +157,13 @@ def biased_grouped_topk_impl(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     scores = gating_output.sigmoid()
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
     group_scores = (
         scores_for_choice.view(num_token, num_expert_group, -1)
@@ -164,8 +186,25 @@ def biased_grouped_topk_impl(
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)
 
+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.
+
     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum
 
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@@ -179,6 +218,7 @@ def biased_grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     compiled: bool = True,
+    n_share_experts_fusion: int = 0,
 ):
     biased_grouped_topk_fn = (
         torch.compile(
@@ -195,6 +235,7 @@ def biased_grouped_topk(
         renormalize,
         num_expert_group,
         topk_group,
+        n_share_experts_fusion=n_share_experts_fusion,
     )
 
 
@@ -210,7 +251,10 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeekSeekv2 uses grouped_top_k
+    n_share_experts_fusion = 0
+    if global_server_args_dict["n_share_experts_fusion"] is not None:
+        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    # DeepSeek V2/V3/R1 series models use grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -222,6 +266,7 @@ def select_experts(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
@@ -232,6 +277,7 @@ def select_experts(
             renormalize=renormalize,
             num_expert_group=num_expert_group,
             topk_group=topk_group,
+            n_share_experts_fusion=n_share_experts_fusion,
         )
    elif torch_native and custom_routing_function is None:
        topk_weights, topk_ids = fused_topk_native(
@@ -51,7 +51,6 @@ except ImportError:
 
 
 from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -203,6 +202,8 @@ def get_linear_quant_method(
 
 
 def gptq_get_quant_method(self, layer, prefix):
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
     if isinstance(layer, FusedMoE):
         return GPTQMarlinMoEMethod(self)
 
@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
                 return UnquantizedLinearMethod()
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
         if isinstance(layer, FusedMoE):
             return CompressedTensorsMoEMethod.get_moe_method(self)
         return None
@@ -4,18 +4,19 @@
 import enum
 import logging
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
-from sglang.srt.layers.moe.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
-from sglang.srt.layers.moe.topk import select_experts
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import (
+        FusedMoE,
+        FusedMoEMethodBase,
+        FusedMoeWeightScaleSupported,
+    )
+
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -55,7 +56,13 @@ __all__ = [
 ]
 
 
-class CompressedTensorsMoEMethod(FusedMoEMethodBase):
+class CompressedTensorsMoEMethod:
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if cls is CompressedTensorsMoEMethod:
+            return super().__new__(cls)
+        return super().__new__(cls)
 
     @staticmethod
     def get_moe_method(
@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import (
+            FusedMoEMethodBase,
+            FusedMoeWeightScaleSupported,
+        )
+
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         params_dtype = torch.float8_e4m3fn
 
@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.fused_moe_triton import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
 
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_weight,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True,
+            inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
             w1_scale=layer.w13_weight_scale,
@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
+        from sglang.srt.layers.moe.fused_moe_triton import (
+            FusedMoEMethodBase,
+            FusedMoeWeightScaleSupported,
+        )
+
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.topk import select_experts
+
         assert activation == "silu", "Only SiLU activation is supported."
         if not VLLM_AVAILABLE:
             raise ImportError(
@@ -81,6 +81,8 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
+    "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
+    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
 }
 
 logger = logging.getLogger(__name__)
@@ -157,6 +157,8 @@ class ModelRunner:
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "n_share_experts_fusion": server_args.n_share_experts_fusion,
+                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
             }
         )
 
@@ -16,12 +16,14 @@
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
 
+import logging
 import os
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
@@ -87,6 +89,8 @@ if _is_hip:
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
+logger = logging.getLogger(__name__)
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -168,6 +172,12 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
+        self.n_share_experts_fusion = (
+            global_server_args_dict["n_share_experts_fusion"]
+            if global_server_args_dict["n_share_experts_fusion"] is not None
+            else 0
+        )
+
         self.routed_scaling_factor = config.routed_scaling_factor
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -188,9 +198,10 @@ class DeepseekV2MoE(nn.Module):
             if global_server_args_dict["enable_deepep_moe"]
             else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
         )
+
         self.experts = MoEImpl(
-            num_experts=config.n_routed_experts,
-            top_k=config.num_experts_per_tok,
+            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
+            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             renormalize=config.norm_topk_prob,
@@ -207,7 +218,7 @@ class DeepseekV2MoE(nn.Module):
             ),
         )
 
-        if config.n_shared_experts is not None:
+        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
             if not global_server_args_dict["enable_deepep_moe"]:
@@ -267,8 +278,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_deepep(hidden_states, forward_mode)
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if self.n_shared_experts is not None:
+        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             shared_output = self.shared_experts(hidden_states)
+        else:
+            shared_output = None
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
         final_hidden_states = (
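
With fusion enabled, the MoE layer is simply built wider: the replicas count as ordinary routed experts and each token gets one extra top-k slot reserved for them. For DeepSeek V3/R1 at `tp_size` 8, the two expressions in the hunk above work out as follows:

```python
n_routed_experts, num_experts_per_tok = 256, 8  # DeepSeek V3/R1 config values
n_share_experts_fusion = 8  # defaults to tp_size (see DeepseekV2ForCausalLM below)

num_experts = n_routed_experts + n_share_experts_fusion       # 264 expert slots
top_k = num_experts_per_tok + min(n_share_experts_fusion, 1)  # 9: one slot for a replica

assert (num_experts, top_k) == (264, 9)
```
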
@@ -1315,7 +1328,28 @@ class DeepseekV2ForCausalLM(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        if (
+            global_server_args_dict.get("disable_shared_experts_fusion", False)
+            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
+            or self.config.n_routed_experts != 256
+            or self.config.routed_scaling_factor != 2.5
+        ):
+            self.n_share_experts_fusion = None
+            global_server_args_dict["n_share_experts_fusion"] = None
+            logger.info(
+                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
+            )
+        elif self.n_share_experts_fusion is None:
+            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
+            self.n_share_experts_fusion = self.tp_size
+            logger.info(
+                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
+            )
+
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -1352,6 +1386,43 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+            weights_list = list(weights)
+            weights_dict = dict(weights_list)
+            suffix_list = [
+                "down_proj.weight",
+                "down_proj.weight_scale_inv",
+                "gate_proj.weight",
+                "gate_proj.weight_scale_inv",
+                "up_proj.weight",
+                "up_proj.weight_scale_inv",
+            ]
+            names_to_remove = []
+            for moe_layer in tqdm(
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                ),
+                desc=f"Cloning {self.n_share_experts_fusion} "
+                "replicas of the shared expert into MoE",
+            ):
+                for num_repeat in range(self.n_share_experts_fusion):
+                    for suffix in suffix_list:
+                        shared_expert_weight_name = (
+                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                        )
+                        weights_list.append(
+                            (
+                                f"model.layers.{moe_layer}."
+                                f"mlp.experts."
+                                f"{self.config.n_routed_experts + num_repeat}"
+                                f".{suffix}",
+                                weights_dict[shared_expert_weight_name].clone(),
+                            )
+                        )
+                        names_to_remove += [shared_expert_weight_name]
+            weights = [w for w in weights_list if w[0] not in names_to_remove]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
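
The weight-loading hunk above materializes the replicas at load time: each shared-expert tensor `mlp.shared_experts.{suffix}` is cloned into the routed-expert slots `mlp.experts.{n_routed_experts + i}` and the standalone shared-expert entry is dropped. A toy sketch of the renaming, using a made-up two-expert checkpoint:

```python
import torch

n_routed, n_fused = 2, 2
prefix = "model.layers.3.mlp"
weights = {
    f"{prefix}.shared_experts.gate_proj.weight": torch.ones(4, 4),
    f"{prefix}.experts.0.gate_proj.weight": torch.zeros(4, 4),
    f"{prefix}.experts.1.gate_proj.weight": torch.zeros(4, 4),
}

shared = f"{prefix}.shared_experts.gate_proj.weight"
for i in range(n_fused):
    # Replica i becomes routed expert n_routed + i.
    weights[f"{prefix}.experts.{n_routed + i}.gate_proj.weight"] = weights[shared].clone()
del weights[shared]  # the standalone shared expert disappears after fusion

assert sorted(weights) == [
    f"{prefix}.experts.{i}.gate_proj.weight" for i in range(n_routed + n_fused)
]
```
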
@@ -1364,7 +1435,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.n_share_experts_fusion
+                if self.n_share_experts_fusion is not None
+                else 0
+            ),
         )
 
         params_dict = dict(self.named_parameters())
@@ -183,6 +183,8 @@ class ServerArgs:
     enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
+    n_share_experts_fusion: Optional[int] = None
+    disable_shared_experts_fusion: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -224,6 +226,9 @@ class ServerArgs:
             # GPU memory is not known yet or no GPU is available.
             gpu_mem = None
 
+        if is_hip():
+            self.disable_shared_experts_fusion = True
+
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -1102,6 +1107,19 @@ class ServerArgs:
             help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
         )
+
+        parser.add_argument(
+            "--n-share-experts-fusion",
+            type=int,
+            default=None,
+            help="The number of shared_experts replicas to fuse with the routed experts "
+            "in DeepSeek V3/R1; tp_size is used by default.",
+        )
+        parser.add_argument(
+            "--disable-shared-experts-fusion",
+            action="store_true",
+            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+        )
 
         # Server warmups
         parser.add_argument(
             "--warmups",
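
Taken together, the two new server flags interact roughly as follows. This is a condensed, illustrative restatement of the defaulting logic spread across `server_args.py` and `deepseek_v2.py` above (`MiniServerArgs` and `resolve_fusion` are made-up names, not part of the commit, and the model-side architecture checks are omitted):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniServerArgs:
    tp_size: int = 8
    n_share_experts_fusion: Optional[int] = None  # --n-share-experts-fusion
    disable_shared_experts_fusion: bool = False   # --disable-shared-experts-fusion


def resolve_fusion(args: MiniServerArgs) -> Optional[int]:
    # Explicitly disabled (the commit also forces this off on HIP platforms).
    if args.disable_shared_experts_fusion:
        return None
    # Default: one shared-expert replica per tensor-parallel rank.
    if args.n_share_experts_fusion is None:
        return args.tp_size
    return args.n_share_experts_fusion


assert resolve_fusion(MiniServerArgs()) == 8
assert resolve_fusion(MiniServerArgs(n_share_experts_fusion=4)) == 4
assert resolve_fusion(MiniServerArgs(disable_shared_experts_fusion=True)) is None
```
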