Optimize Permute Kernel in DeepEP (#4643)
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
This commit is contained in:
@@ -17,52 +17,6 @@ if _is_cuda:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def compute_src2dst_triton_kernel(
|
|
||||||
reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
|
|
||||||
):
|
|
||||||
pid = tl.program_id(axis=0)
|
|
||||||
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = dst_id < num_toks
|
|
||||||
src_id = tl.load(reorder_ids + dst_id, mask=mask)
|
|
||||||
tl.store(src2dst + src_id, dst_id, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def deepep_compute_src2dst_triton_kernel(
|
|
||||||
reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
|
|
||||||
):
|
|
||||||
pid = tl.program_id(axis=0)
|
|
||||||
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = dst_id < num_toks
|
|
||||||
src_id = tl.load(reorder_ids + dst_id, mask=mask)
|
|
||||||
num_invalid = tl.load(num_minus_one)
|
|
||||||
tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
|
|
||||||
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
|
||||||
seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
|
||||||
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
|
|
||||||
|
|
||||||
# Find offset
|
|
||||||
expert_ids = torch.arange(
|
|
||||||
num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
|
|
||||||
)
|
|
||||||
torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
|
|
||||||
num_minus_one = seg_indptr[0]
|
|
||||||
seg_indptr = seg_indptr - num_minus_one
|
|
||||||
|
|
||||||
BLOCK_SIZE = 512
|
|
||||||
grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
|
|
||||||
deepep_compute_src2dst_triton_kernel[grid](
|
|
||||||
reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
|
|
||||||
)
|
|
||||||
|
|
||||||
reorder_topk_ids = reorder_topk_ids[num_minus_one:]
|
|
||||||
return reorder_topk_ids, src2dst, seg_indptr
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def deepep_permute_triton_kernel(
|
def deepep_permute_triton_kernel(
|
||||||
input_ptr,
|
input_ptr,
|
||||||
@@ -85,14 +39,13 @@ def deepep_permute_triton_kernel(
|
|||||||
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
|
||||||
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
offset = start_offset + tl.arange(0, BLOCK_SIZE)
|
||||||
mask = offset < hidden_size
|
mask = offset < hidden_size
|
||||||
in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
|
in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)
|
||||||
|
|
||||||
for idx in range(topk):
|
for idx in range(topk):
|
||||||
dst_idx = tl.load(src2dst_ptr + idx)
|
dst_idx = tl.load(src2dst_ptr + idx)
|
||||||
if dst_idx >= 0:
|
if dst_idx >= 0:
|
||||||
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
dst_ptr = gateup_input_ptr + dst_idx * hidden_size
|
||||||
out_data = (in_data).to(OutDtype)
|
tl.store(dst_ptr + offset, in_data, mask=mask)
|
||||||
tl.store(dst_ptr + offset, out_data, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -128,6 +81,51 @@ def deepep_post_reorder_triton_kernel(
|
|||||||
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
tl.store(store_ptr + offset, sum_vec, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def compute_src2dst_triton_kernel(
|
||||||
|
reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = dst_id < num_toks
|
||||||
|
src_id = tl.load(reorder_ids + dst_id, mask=mask)
|
||||||
|
tl.store(src2dst + src_id, dst_id, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def deepep_compute_src2dst_triton_kernel(
|
||||||
|
reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
|
||||||
|
):
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = dst_id < num_toks
|
||||||
|
src_id = tl.load(reorder_ids + dst_id, mask=mask)
|
||||||
|
num_invalid = tl.load(num_minus_one)
|
||||||
|
tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
|
||||||
|
reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
|
||||||
|
seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
|
||||||
|
src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
|
||||||
|
|
||||||
|
# Find offset
|
||||||
|
expert_ids = torch.arange(
|
||||||
|
num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
|
||||||
|
)
|
||||||
|
torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
|
||||||
|
num_minus_one = seg_indptr[0]
|
||||||
|
seg_indptr = seg_indptr - num_minus_one
|
||||||
|
|
||||||
|
BLOCK_SIZE = 512
|
||||||
|
grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
|
||||||
|
deepep_compute_src2dst_triton_kernel[grid](
|
||||||
|
reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
|
||||||
|
)
|
||||||
|
reorder_topk_ids = reorder_topk_ids[num_minus_one:]
|
||||||
|
return reorder_topk_ids, src2dst, seg_indptr
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
|
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
|
||||||
expert = tl.program_id(0)
|
expert = tl.program_id(0)
|
||||||
|
|||||||
@@ -831,19 +831,23 @@ class DeepEPMoE(EPMoE):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
tokens_per_expert: torch.Tensor,
|
reorder_topk_ids: torch.Tensor,
|
||||||
|
seg_indptr: torch.Tensor,
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
):
|
):
|
||||||
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
|
# Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
|
||||||
if True: # not forward_mode.is_decode():
|
if True: # not forward_mode.is_decode():
|
||||||
return self.forward_normal(hidden_states, tokens_per_expert)
|
return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
|
||||||
else:
|
else:
|
||||||
return self.forward_deepgemm_masked(hidden_states, tokens_per_expert)
|
return self.forward_deepgemm_masked(
|
||||||
|
hidden_states, reorder_topk_ids, seg_indptr
|
||||||
|
)
|
||||||
|
|
||||||
def forward_normal(
|
def forward_normal(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
tokens_per_expert: torch.Tensor,
|
reorder_topk_ids: torch.Tensor,
|
||||||
|
seg_indptr: torch.Tensor,
|
||||||
):
|
):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
assert self.activation == "silu"
|
assert self.activation == "silu"
|
||||||
@@ -851,15 +855,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
self.grouped_gemm_runner = GroupedGemmRunner(
|
self.grouped_gemm_runner = GroupedGemmRunner(
|
||||||
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
|
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
|
||||||
)
|
)
|
||||||
seg_indptr_cur_rank = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(
|
|
||||||
1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype
|
|
||||||
),
|
|
||||||
torch.cumsum(tokens_per_expert, dim=0),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
|
|
||||||
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
if self.activation_scheme == "dynamic" and not self.use_block_quant:
|
||||||
max_value = (
|
max_value = (
|
||||||
torch.max(hidden_states)
|
torch.max(hidden_states)
|
||||||
@@ -881,6 +877,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
gateup_output = self.grouped_gemm_runner(
|
gateup_output = self.grouped_gemm_runner(
|
||||||
a=hidden_states,
|
a=hidden_states,
|
||||||
@@ -888,7 +885,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
c=gateup_output,
|
c=gateup_output,
|
||||||
batch_size=self.num_experts_per_partition,
|
batch_size=self.num_experts_per_partition,
|
||||||
weight_column_major=True,
|
weight_column_major=True,
|
||||||
seg_indptr=seg_indptr_cur_rank,
|
seg_indptr=seg_indptr,
|
||||||
weight_indices=weight_indices_cur_rank,
|
weight_indices=weight_indices_cur_rank,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||||
scale_a=self.w13_input_scale,
|
scale_a=self.w13_input_scale,
|
||||||
@@ -946,7 +943,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
c=down_output,
|
c=down_output,
|
||||||
batch_size=self.num_experts_per_partition,
|
batch_size=self.num_experts_per_partition,
|
||||||
weight_column_major=True,
|
weight_column_major=True,
|
||||||
seg_indptr=seg_indptr_cur_rank,
|
seg_indptr=seg_indptr,
|
||||||
weight_indices=weight_indices_cur_rank,
|
weight_indices=weight_indices_cur_rank,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||||
scale_a=self.w2_input_scale,
|
scale_a=self.w2_input_scale,
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
compute_src2dst_triton_kernel,
|
|
||||||
deepep_permute_triton_kernel,
|
deepep_permute_triton_kernel,
|
||||||
deepep_post_reorder_triton_kernel,
|
deepep_post_reorder_triton_kernel,
|
||||||
deepep_run_moe_deep_preprocess,
|
deepep_run_moe_deep_preprocess,
|
||||||
@@ -86,90 +85,6 @@ def get_buffer_low_latency(
|
|||||||
return _buffer_low_latency
|
return _buffer_low_latency
|
||||||
|
|
||||||
|
|
||||||
def permute(
|
|
||||||
tokens,
|
|
||||||
routing_map,
|
|
||||||
num_out_tokens: Optional[int] = None,
|
|
||||||
fused: bool = False,
|
|
||||||
drop_and_pad: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Copy from Megatron-Core moe for token permutation
|
|
||||||
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
num_tokens, _ = tokens.shape
|
|
||||||
num_experts = routing_map.shape[1]
|
|
||||||
if drop_and_pad and not (num_out_tokens is None):
|
|
||||||
capacity = num_out_tokens // num_experts
|
|
||||||
assert not routing_map.requires_grad
|
|
||||||
routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
|
|
||||||
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
|
|
||||||
:, :capacity
|
|
||||||
].contiguous()
|
|
||||||
sorted_indices = sorted_indices.view(-1)
|
|
||||||
else:
|
|
||||||
routing_map = routing_map.bool().T.contiguous()
|
|
||||||
token_indices = (
|
|
||||||
torch.arange(num_tokens, device=routing_map.device)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(num_experts, -1)
|
|
||||||
)
|
|
||||||
sorted_indices = token_indices.masked_select(routing_map)
|
|
||||||
permuted_input = tokens.index_select(0, sorted_indices)
|
|
||||||
|
|
||||||
return permuted_input, sorted_indices
|
|
||||||
|
|
||||||
|
|
||||||
def unpermute(
|
|
||||||
permuted_tokens: torch.Tensor,
|
|
||||||
sorted_indices: torch.Tensor,
|
|
||||||
restore_shape: torch.Size,
|
|
||||||
probs: torch.Tensor = None,
|
|
||||||
routing_map: torch.Tensor = None,
|
|
||||||
fused: bool = False,
|
|
||||||
drop_and_pad: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Copy from Megatron-Core moe for token unpermutation
|
|
||||||
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
_, hidden = restore_shape
|
|
||||||
|
|
||||||
if probs is not None:
|
|
||||||
assert routing_map is not None, "Mask must be provided to permute the probs."
|
|
||||||
if drop_and_pad:
|
|
||||||
num_experts = routing_map.size(1)
|
|
||||||
num_permuted_tokens = sorted_indices.size(0)
|
|
||||||
capacity = num_permuted_tokens // num_experts
|
|
||||||
num_unpermuted_tokens = probs.size(0)
|
|
||||||
|
|
||||||
probs_T_1D = probs.T.contiguous().view(-1)
|
|
||||||
|
|
||||||
indices_dim0 = torch.arange(
|
|
||||||
num_experts, device=routing_map.device
|
|
||||||
).unsqueeze(-1)
|
|
||||||
indices_dim1 = sorted_indices.view(num_experts, capacity)
|
|
||||||
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
|
|
||||||
|
|
||||||
permuted_probs = probs_T_1D.index_select(0, indices_1D)
|
|
||||||
else:
|
|
||||||
permuted_probs = probs.T.contiguous().masked_select(
|
|
||||||
routing_map.T.contiguous()
|
|
||||||
)
|
|
||||||
permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
|
|
||||||
|
|
||||||
output_tokens = torch.zeros(
|
|
||||||
restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype
|
|
||||||
)
|
|
||||||
output_tokens.scatter_add_(
|
|
||||||
0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
return output_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class DeepEPDispatcher:
|
class DeepEPDispatcher:
|
||||||
"""
|
"""
|
||||||
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
||||||
@@ -228,16 +143,13 @@ class DeepEPDispatcher:
|
|||||||
|
|
||||||
def deepep_permute(
|
def deepep_permute(
|
||||||
self,
|
self,
|
||||||
topk_ids,
|
|
||||||
hidden_states,
|
hidden_states,
|
||||||
num_experts,
|
fp8_dtype=None,
|
||||||
top_k,
|
use_fp8_w8a8=False,
|
||||||
use_fp8_w8a8,
|
use_block_quant=False,
|
||||||
use_block_quant,
|
|
||||||
fp8_dtype,
|
|
||||||
):
|
):
|
||||||
reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||||
topk_ids, num_experts
|
self.topk_idx, self.num_experts
|
||||||
)
|
)
|
||||||
num_total_tokens = reorder_topk_ids.numel()
|
num_total_tokens = reorder_topk_ids.numel()
|
||||||
gateup_input = torch.empty(
|
gateup_input = torch.empty(
|
||||||
@@ -254,9 +166,9 @@ class DeepEPDispatcher:
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
gateup_input,
|
gateup_input,
|
||||||
src2dst,
|
src2dst,
|
||||||
topk_ids,
|
self.topk_idx,
|
||||||
None,
|
None,
|
||||||
top_k,
|
self.router_topk,
|
||||||
hidden_states.shape[1],
|
hidden_states.shape[1],
|
||||||
BLOCK_SIZE=512,
|
BLOCK_SIZE=512,
|
||||||
)
|
)
|
||||||
@@ -302,13 +214,21 @@ class DeepEPDispatcher:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.recv_expert_count = recv_expert_count
|
self.recv_expert_count = recv_expert_count
|
||||||
tokens_per_expert = self.get_number_of_tokens_per_expert()
|
|
||||||
self.handle = handle
|
self.handle = handle
|
||||||
self.topk_idx = topk_idx
|
self.topk_idx = topk_idx
|
||||||
self.topk_weights = topk_weights
|
self.topk_weights = topk_weights
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
|
reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
|
||||||
return hidden_states, topk_idx, topk_weights, tokens_per_expert
|
hidden_states, fp8_dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
reorder_topk_ids = torch.empty(
|
||||||
|
(0,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
seg_indptr = torch.zeros(
|
||||||
|
(num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
return hidden_states, reorder_topk_ids, seg_indptr
|
||||||
|
|
||||||
def dispatch_normal(
|
def dispatch_normal(
|
||||||
self,
|
self,
|
||||||
@@ -427,10 +347,29 @@ class DeepEPDispatcher:
|
|||||||
# Todo: enable low latency combine
|
# Todo: enable low latency combine
|
||||||
if True: # not forward_mode.is_decode():
|
if True: # not forward_mode.is_decode():
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
hidden_states = self.get_restored_hidden_states_by_experts(
|
num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||||
hidden_states
|
output = torch.empty(
|
||||||
|
(num_tokens, hidden_states.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
hidden_states, event = self.combine_normal(hidden_states, self.handle)
|
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||||
|
hidden_states,
|
||||||
|
output,
|
||||||
|
self.src2dst,
|
||||||
|
self.topk_idx,
|
||||||
|
self.topk_weights,
|
||||||
|
self.router_topk,
|
||||||
|
hidden_states.shape[1],
|
||||||
|
BLOCK_SIZE=512,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = torch.zeros(
|
||||||
|
(0, hidden_states.shape[1]),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
hidden_states, event = self.combine_normal(output, self.handle)
|
||||||
else:
|
else:
|
||||||
hidden_states, event, hook = self.combine_low_latency(
|
hidden_states, event, hook = self.combine_low_latency(
|
||||||
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
hidden_states, self.topk_idx, self.topk_weights, self.handle
|
||||||
@@ -467,67 +406,3 @@ class DeepEPDispatcher:
|
|||||||
)
|
)
|
||||||
# hook()
|
# hook()
|
||||||
return combined_hidden_states, event_overlap, hook
|
return combined_hidden_states, event_overlap, hook
|
||||||
|
|
||||||
def _indices_to_multihot(self, indices, probs):
|
|
||||||
batch_size = indices.shape[0]
|
|
||||||
multihot_routing_map = torch.zeros(
|
|
||||||
(batch_size, self.num_local_experts),
|
|
||||||
dtype=torch.long,
|
|
||||||
device=indices.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
multihot_probs = torch.zeros(
|
|
||||||
(batch_size, self.num_local_experts),
|
|
||||||
dtype=torch.float,
|
|
||||||
device=indices.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
mask = indices != -1
|
|
||||||
valid_indices = indices[mask]
|
|
||||||
row_indices = torch.arange(batch_size, device=indices.device).repeat_interleave(
|
|
||||||
mask.sum(dim=1)
|
|
||||||
)
|
|
||||||
multihot_routing_map[row_indices, valid_indices] = 1
|
|
||||||
multihot_probs[row_indices, valid_indices] = probs[mask]
|
|
||||||
return multihot_routing_map.bool(), multihot_probs
|
|
||||||
|
|
||||||
def get_dispached_metadata(self) -> torch.Tensor:
|
|
||||||
return self.topk_idx, self.topk_weights
|
|
||||||
|
|
||||||
def get_number_of_tokens_per_expert(self) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Get the number of tokens per expert.
|
|
||||||
"""
|
|
||||||
return self.tokens_per_expert
|
|
||||||
|
|
||||||
def get_permuted_hidden_states_by_experts(
|
|
||||||
self, hidden_states: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
self.dispatched_routing_map, self.topk_weights = self._indices_to_multihot(
|
|
||||||
self.topk_idx, self.topk_weights
|
|
||||||
)
|
|
||||||
self.hidden_shape_before_permute = hidden_states.shape
|
|
||||||
hidden_states, self.reversed_mapping_for_combine = permute(
|
|
||||||
hidden_states,
|
|
||||||
self.dispatched_routing_map,
|
|
||||||
num_out_tokens=self.tokens_per_expert.sum(),
|
|
||||||
fused=self.permute_fusion,
|
|
||||||
)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def get_restored_hidden_states_by_experts(
|
|
||||||
self, hidden_states: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
assert (
|
|
||||||
self.topk_weights.dtype == torch.float32
|
|
||||||
), "DeepEP only supports float32 probs"
|
|
||||||
hidden_states = unpermute(
|
|
||||||
hidden_states,
|
|
||||||
self.reversed_mapping_for_combine,
|
|
||||||
restore_shape=self.hidden_shape_before_permute,
|
|
||||||
routing_map=self.dispatched_routing_map,
|
|
||||||
probs=self.topk_weights,
|
|
||||||
fused=self.permute_fusion,
|
|
||||||
)
|
|
||||||
return hidden_states.to(input_dtype)
|
|
||||||
|
|||||||
@@ -294,7 +294,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
correction_bias=self.correction_bias,
|
correction_bias=self.correction_bias,
|
||||||
)
|
)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
|
recv_hidden_states, reorder_topk_ids, seg_indptr = (
|
||||||
self.deepep_dispatcher.dispatch(
|
self.deepep_dispatcher.dispatch(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
@@ -306,7 +306,8 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
final_hidden_states = (
|
final_hidden_states = (
|
||||||
self.experts(
|
self.experts(
|
||||||
hidden_states=recv_hidden_states,
|
hidden_states=recv_hidden_states,
|
||||||
tokens_per_expert=tokens_per_expert,
|
reorder_topk_ids=reorder_topk_ids,
|
||||||
|
seg_indptr=seg_indptr,
|
||||||
forward_mode=forward_mode,
|
forward_mode=forward_mode,
|
||||||
)
|
)
|
||||||
* self.routed_scaling_factor
|
* self.routed_scaling_factor
|
||||||
|
|||||||
Reference in New Issue
Block a user