Fuse routed scaling factor in deepseek (#6970)

benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py (new file, 199 lines)
@@ -0,0 +1,199 @@
import torch
import triton
import triton.language as tl
from triton.testing import do_bench


# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    token_start = token_block_id * BLOCK_M
    token_end = min((token_block_id + 1) * BLOCK_M, token_num)

    dim_start = dim_block_id * BLOCK_DIM
    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)

    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)

    for token_index in range(token_start, token_end):
        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
            tmp = tl.load(
                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
            )
            accumulator += tmp
        accumulator = accumulator * routed_scaling_factor
        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        tl.store(
            store_t_ptr,
            accumulator.to(input_ptr.dtype.element_ty),
            mask=offs_dim < dim_end,
        )


def moe_sum_reduce(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 8

    grid = (
        triton.cdiv(token_num, BLOCK_M),
        triton.cdiv(hidden_dim, BLOCK_DIM),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )
    return
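
# Launch-configuration note: BLOCK_M, BLOCK_DIM, NUM_STAGE and num_warps above
# are fixed rather than autotuned. As a worked example (values assumed), a
# (1024, 9, 4096) input with BLOCK_M=1 and BLOCK_DIM=2048 launches
# grid = (cdiv(1024, 1), cdiv(4096, 2048)) = (1024, 2) programs: one per token
# and one per 2048-wide slice of the hidden dimension.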


def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)
    return out


@torch.compile
def compute_sum_scaled_compiled(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
    torch.sum(x * routed_scaling_factor, dim=1, out=out)
    return out


def get_benchmark():
    num_tokens_range = [2**i for i in range(0, 13)]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=num_tokens_range,
            line_arg="version",
            line_vals=["baseline", "compiled", "triton"],
            line_names=["Original", "TorchCompile", "TritonKernel"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name="sum_scaled_performance",
            args={},
        )
    )
    def benchmark(num_tokens, version):
        topk = 9
        hidden_size = 4096
        dtype = torch.bfloat16
        scaling_factor = 0.3

        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")

        # Warmup
        for _ in range(3):
            if version == "baseline":
                compute_sum_scaled_baseline(x, out, scaling_factor)
            elif version == "compiled":
                compute_sum_scaled_compiled(x, out, scaling_factor)
            else:
                moe_sum_reduce(x, out, scaling_factor)

        # Benchmark
        quantiles = [0.5, 0.2, 0.8]
        if version == "baseline":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
                quantiles=quantiles,
            )
        elif version == "compiled":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = do_bench(
                lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


def verify_correctness(num_tokens=1024):
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=torch.bfloat16)
    scaling_factor = 0.3

    out_baseline = torch.empty_like(x[:, 0])
    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)

    out_compiled = torch.empty_like(out_baseline)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_triton = torch.empty_like(out_baseline)
    moe_sum_reduce(x, out_triton, scaling_factor)

    if torch.allclose(
        out_baseline, out_compiled, atol=1e-2, rtol=1e-2
    ) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
        print(
            f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
        )
        print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}")


if __name__ == "__main__":
    print("Running correctness verification...")
    verify_correctness()

    print("\nRunning performance benchmark...")
    benchmark = get_benchmark()
    benchmark.run(
        print_data=True,
        # save_path="./configs/benchmark_ops/sum_scaled/"
    )
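
For reference, the kernel benchmarked above computes out[t, d] = (sum over k of x[t, k, d]) * routed_scaling_factor, accumulating in fp32 and casting back to the input dtype on store. A minimal PyTorch sketch of the same semantics (moe_sum_reduce_ref is an illustrative name, not part of this commit):

def moe_sum_reduce_ref(x: torch.Tensor, routed_scaling_factor: float) -> torch.Tensor:
    # x: (token_num, topk_num, hidden_dim) -> (token_num, hidden_dim)
    return (x.float().sum(dim=1) * routed_scaling_factor).to(x.dtype)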
@@ -1155,6 +1155,7 @@ def inplace_fused_experts(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1177,6 +1178,8 @@ def inplace_fused_experts(
         a1_scale,
         a2_scale,
         block_shape,
+        False,
+        routed_scaling_factor,
     )
@@ -1200,6 +1203,7 @@ def inplace_fused_experts_fake(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
+    routed_scaling_factor: Optional[float] = None,
 ) -> None:
     pass
@@ -1233,6 +1237,7 @@ def outplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1256,6 +1261,7 @@ def outplace_fused_experts(
         a2_scale,
         block_shape,
         no_combine=no_combine,
+        routed_scaling_factor=routed_scaling_factor,
     )
@@ -1280,6 +1286,7 @@ def outplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
@@ -1314,7 +1321,9 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
     if inplace:
         assert not no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -1337,6 +1346,7 @@ def fused_experts(
             a1_scale,
             a2_scale,
             block_shape,
+            routed_scaling_factor,
         )
         return hidden_states
     else:
@@ -1361,9 +1371,102 @@ def fused_experts(
             a2_scale,
             block_shape,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
 
+# _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
+@triton.jit
+def _moe_sum_reduce_kernel(
+    input_ptr,
+    input_stride_0,
+    input_stride_1,
+    input_stride_2,
+    output_ptr,
+    output_stride_0,
+    output_stride_1,
+    token_num: int,
+    topk_num: int,
+    hidden_dim: int,
+    routed_scaling_factor: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DIM: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+
+    token_block_id = tl.program_id(0)
+    dim_block_id = tl.program_id(1)
+
+    token_start = token_block_id * BLOCK_M
+    token_end = min((token_block_id + 1) * BLOCK_M, token_num)
+
+    dim_start = dim_block_id * BLOCK_DIM
+    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
+
+    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
+
+    for token_index in range(token_start, token_end):
+        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
+        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
+        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+            tmp = tl.load(
+                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
+            )
+            accumulator += tmp
+        accumulator = accumulator * routed_scaling_factor
+        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
+        tl.store(
+            store_t_ptr,
+            accumulator.to(input_ptr.dtype.element_ty),
+            mask=offs_dim < dim_end,
+        )
+
+
+def moe_sum_reduce_triton(
+    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
+):
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+
+    token_num, topk_num, hidden_dim = input.shape
+    assert output.shape[0] == token_num and output.shape[1] == hidden_dim
+
+    BLOCK_M = 1
+    BLOCK_DIM = 2048
+    NUM_STAGE = 1
+    num_warps = 8
+
+    grid = (
+        triton.cdiv(token_num, BLOCK_M),
+        triton.cdiv(hidden_dim, BLOCK_DIM),
+    )
+
+    _moe_sum_reduce_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        token_num=token_num,
+        topk_num=topk_num,
+        hidden_dim=hidden_dim,
+        routed_scaling_factor=routed_scaling_factor,
+        BLOCK_M=BLOCK_M,
+        BLOCK_DIM=BLOCK_DIM,
+        NUM_STAGE=NUM_STAGE,
+        num_warps=num_warps,
+    )
+    return
+
+
+@torch.compile
+def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
+    torch.sum(x, dim=1, out=out)
+    out.mul_(routed_scaling_factor)
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
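
A quick sanity check for the helpers added in this hunk (an illustrative snippet, not part of the commit; assumes a CUDA device and the benchmark's default shapes):

x = torch.randn(16, 9, 4096, device="cuda", dtype=torch.bfloat16)
out = torch.empty(16, 4096, device="cuda", dtype=torch.bfloat16)
moe_sum_reduce_triton(x, out, 0.3)
ref = (x.float().sum(dim=1) * 0.3).to(torch.bfloat16)
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)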
@@ -1386,6 +1489,7 @@ def fused_experts_impl(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
     if (
@@ -1562,28 +1666,39 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
 
+        if routed_scaling_factor is None:
+            routed_scaling_factor = 1.0
+
         if no_combine:
             pass
-        elif _is_hip:
-            vllm_ops.moe_sum(
-                intermediate_cache3.view(*intermediate_cache3.shape),
-                out_hidden_states[begin_chunk_idx:end_chunk_idx],
-            )
-        else:
-            if topk_ids.shape[1] == 1:
+        elif _is_cuda:
+            if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
                 pass  # we write directly into out_hidden_states
-            elif topk_ids.shape[1] == 2:
+            elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
                 torch.add(
                     intermediate_cache3[:, 0],
                     intermediate_cache3[:, 1],
                     out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 ).squeeze(dim=1)
-            elif topk_ids.shape[1] > 2:
-                torch.sum(
-                    intermediate_cache3.view(*intermediate_cache3.shape),
-                    dim=1,
-                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                )
+            else:
+                # According to micro benchmark results, torch.compile can get
+                # better performance for small token counts.
+                if tokens_in_chunk <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+        else:
+            vllm_ops.moe_sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
 
     return out_hidden_states
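
The rewrite is numerics-preserving because the scaling factor distributes over the top-k sum, so folding the multiply into the combine step removes a separate elementwise pass without changing the result. A self-contained check of the identity (illustrative, CPU-only):

import torch

x = torch.randn(4, 9, 64)  # (tokens, topk, hidden)
s = 0.3
# scale-after-sum == scale-before-sum, up to floating-point rounding
assert torch.allclose(x.sum(dim=1) * s, (x * s).sum(dim=1), atol=1e-5)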
@@ -1695,4 +1810,5 @@ def fused_moe(
         a2_scale=a2_scale,
         block_shape=block_shape,
         no_combine=no_combine,
+        routed_scaling_factor=routed_scaling_factor,
     )

@@ -225,6 +225,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
     def forward_cpu(

@@ -411,4 +411,5 @@ class BlockInt8MoEMethod:
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

@@ -317,6 +317,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             apply_router_weight_on_input=apply_router_weight_on_input,
+            routed_scaling_factor=routed_scaling_factor,
         )

@@ -1030,6 +1030,7 @@ class Fp8MoEMethod:
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
 def maybe_apply_hip_fused_experts(

@@ -388,6 +388,7 @@ class MoeWNA16Method:
             w2_zp=layer.w2_qzeros if has_zp else None,
             block_shape=[0, layer.group_size],
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
     @staticmethod

@@ -328,4 +328,5 @@ class W8A8FP8MoEMethod:
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

@@ -268,4 +268,5 @@ class W8A8Int8MoEMethod:
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
@@ -346,7 +346,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        final_hidden_states *= self.routed_scaling_factor
+        if not _is_cuda:
+            final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
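
The new guard prevents double scaling: on CUDA, fused_experts now applies routed_scaling_factor inside its combine step, so the model-level multiply must be skipped there. A toy illustration (values hypothetical):

import torch

s = 2.5  # hypothetical routed scaling factor
expert_sum = torch.randn(4, 8)  # stand-in for the combined expert output
fused = expert_sum * s          # what the CUDA path already returns
assert torch.allclose(fused * s, expert_sum * s * s)  # an unconditional multiply would over-scale by s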