Refactor DeepSeek logic into atomic operations (#6326)
@@ -51,7 +51,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -83,6 +83,8 @@ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.operations import execute_operations
+from sglang.srt.operations_strategy import compute_layer_operations
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
@@ -304,99 +306,123 @@ class DeepseekV2MoE(nn.Module):
     def _enable_deepep_moe(self):
         return global_server_args_dict["enable_deepep_moe"]
 
-    def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
-    ) -> torch.Tensor:
-        forward_mode = forward_batch.forward_mode
+    def op_gate(self, state):
         if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
-            forward_mode, hidden_states
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
         ):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            state.router_logits = self.gate(state.hidden_states_mlp_input)
         else:
-            router_logits = None
+            state.router_logits = None
 
+    def op_shared_experts(self, state):
         if (self.n_share_experts_fusion == 0) and (
             (not self._enable_deepep_moe)
-            or is_non_idle_and_non_empty(forward_mode, hidden_states)
+            or is_non_idle_and_non_empty(
+                state.forward_batch.forward_mode, state.hidden_states_mlp_input
+            )
         ):
-            shared_output = self.shared_experts(hidden_states)
+            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
         else:
-            shared_output = None
+            state.shared_output = None
 
-        if self._enable_deepep_moe and (router_logits is not None):
-            topk_weights, topk_idx = select_experts(
-                hidden_states=hidden_states,
-                router_logits=router_logits,
-                top_k=self.top_k,
-                use_grouped_topk=True,
-                renormalize=self.renormalize,
-                topk_group=self.topk_group,
-                num_expert_group=self.num_expert_group,
-                correction_bias=self.correction_bias,
-                routed_scaling_factor=self.routed_scaling_factor,
-                num_token_non_padded=forward_batch.num_token_non_padded,
-            )
-        else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
-
-        if self._enable_deepep_moe and (self.ep_size > 1):
-            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode=forward_mode,
-            )
+    def op_select_experts(self, state):
+        router_logits = state.router_logits
+        hidden_states = state.hidden_states_mlp_input
 
-        if self._enable_deepep_moe:
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states,
-                topk_idx=topk_idx,
-                topk_weights=topk_weights,
-                reorder_topk_ids=reorder_topk_ids,
-                seg_indptr=seg_indptr,
-                masked_m=masked_m,
-                expected_m=expected_m,
-                num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-                forward_mode=forward_mode,
-            )
-        else:
-            final_hidden_states = self.experts(
-                hidden_states=hidden_states, router_logits=router_logits
+        if router_logits is not None:
+            state.topk_weights_local, state.topk_idx_local = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                correction_bias=self.correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
             )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+
+    def op_dispatch_a(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+
+    def op_dispatch_b(self, state):
         if self._enable_deepep_moe and (self.ep_size > 1):
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+            (
+                state.hidden_states_experts_input,
+                state.topk_idx_dispatched,
+                state.topk_weights_dispatched,
+                state.reorder_topk_ids,
+                state.num_recv_tokens_per_expert,
+                state.seg_indptr,
+                state.masked_m,
+                state.expected_m,
+            ) = self.deepep_dispatcher.dispatch_b()
+
+    def op_experts(self, state):
+        if self._enable_deepep_moe:
+            state.pop("router_logits")
+            state.hidden_states_experts_output = self.experts(
+                hidden_states=state.pop("hidden_states_experts_input"),
+                topk_idx=state.topk_idx_dispatched,
+                topk_weights=state.topk_weights_dispatched,
+                reorder_topk_ids=state.pop("reorder_topk_ids"),
+                seg_indptr=state.pop("seg_indptr"),
+                masked_m=state.pop("masked_m"),
+                expected_m=state.pop("expected_m"),
+                num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+        else:
+            state.hidden_states_experts_output = self.experts(
+                hidden_states=state.pop("hidden_states_mlp_input"),
+                router_logits=state.pop("router_logits"),
             )
 
+    def op_combine_a(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            self.deepep_dispatcher.combine_a(
+                state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+            )
+
+    def op_combine_b(self, state):
+        if self._enable_deepep_moe and (self.ep_size > 1):
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
+
+    def op_output(self, state):
+        final_hidden_states = (
+            state.pop("hidden_states_after_combine")
+            if self._enable_deepep_moe
+            else state.pop("hidden_states_experts_output")
+        )
 
         final_hidden_states *= self.routed_scaling_factor
 
-        if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+        if (s := state.pop("shared_output")) is not None:
+            final_hidden_states = final_hidden_states + s
 
         if (not self._enable_deepep_moe) and (self.tp_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-        return final_hidden_states
+        state.hidden_states_mlp_output = final_hidden_states
 
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
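The monolithic `forward` above is decomposed into nine `op_*` methods that communicate only through a shared state object. As a minimal sketch (a hypothetical driver for illustration; the real sequencing comes from `compute_layer_operations` and `execute_operations`, both added below, and `moe`, `hidden_states`, and `forward_batch` are placeholders), running the ops back to back reproduces the old forward pass:

    from sglang.srt.operations import _StateDict  # private helper, used here purely for illustration

    state = _StateDict()
    state.update(dict(hidden_states_mlp_input=hidden_states, forward_batch=forward_batch))
    for op in [
        moe.op_gate, moe.op_shared_experts, moe.op_select_experts,
        moe.op_dispatch_a, moe.op_dispatch_b, moe.op_experts,
        moe.op_combine_a, moe.op_combine_b, moe.op_output,
    ]:
        op(state)  # each op pops the inputs it consumes and writes its outputs
    out = state.pop("hidden_states_mlp_output")  # what forward() used to return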
@@ -1197,27 +1223,77 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        return execute_operations(
+            inputs=dict(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+                residual=residual,
+                zero_allocator=zero_allocator,
+            ),
+            operations=compute_layer_operations(self),
         )
 
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-            zero_allocator=zero_allocator,
+    def op_comm_prepare_attn(
+        self,
+        state,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        zero_allocator: BumpAllocator,
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(
+                hidden_states, residual, state.forward_batch
+            )
         )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                zero_allocator=zero_allocator,
+            )
+        )
 
-        hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states, residual, forward_batch
+    def op_attn(self, state):
+        state.hidden_states_after_attn = self.self_attn(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
+
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
+        if not (
+            enable_moe_dense_fully_dp()
+            and (not self.is_layer_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            state.hidden_states_mlp_output = self.mlp(
+                hidden_states, state.forward_batch.forward_mode
+            )
+        else:
+            state.hidden_states_mlp_output = hidden_states
 
+    def op_comm_postprocess_layer(self, state):
         hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
         )
 
+        state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
         return hidden_states, residual
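Throughout the new `op_*` methods, a value that is consumed for the last time is read with `state.pop(...)` rather than attribute access, so the tensor's reference is dropped as soon as its final consumer runs; `state.clear(expect_keys=...)` at the end of the layer then asserts that nothing else was accidentally kept alive. A toy illustration of the convention (hypothetical ops, not from the diff):

    def op_produce(self, state):
        state.activation = torch.relu(state.pop("x"))  # last use of "x": freed here

    def op_consume(self, state):
        state.y = state.pop("activation") * 2  # last consumer pops rather than reads

    # afterwards, state.clear(expect_keys={"y"}) passes; a leaked "activation"
    # would raise instead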
python/sglang/srt/operations.py (new file, 154 lines)
@@ -0,0 +1,154 @@
+import os
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+
+import torch
+
+_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
+
+if _ENABLE_PROFILE:
+    import nvtx
+
+
+def execute_operations(inputs, operations):
+    stages = _convert_operations_to_stages(decorate_operations(operations))
+    executor = _StageExecutor("primary", stages, inputs=inputs)
+    for _ in range(executor.num_stages):
+        executor.next()
+    assert executor.done
+    return executor.output
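A toy usage of `execute_operations` (illustrative, not part of the file): `YieldOperation` markers split the operation list into stages; each op receives the shared `state` plus the previous op's return value (or, for the very first op, the initial `inputs`) unpacked as keyword arguments, and the last return value becomes the overall output.

    def op_a(state, x):
        state.doubled = x * 2

    def op_b(state):
        return dict(result=state.pop("doubled") + 1)

    out = execute_operations(
        inputs=dict(x=20),
        operations=[op_a, YieldOperation(), op_b],
    )
    assert out == {"result": 41}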
+
+class YieldOperation:
+    pass
+
+
+@dataclass
+class ExecutionOperation:
+    debug_name: str
+    fn: Callable
+
+
+Operation = Union[YieldOperation, ExecutionOperation, Callable]
+Stage = List[ExecutionOperation]
+
+
+class _StageExecutor:
+    def __init__(self, debug_name: str, stages: List[Stage], inputs):
+        self._debug_name = debug_name
+        self._stages = stages
+        self._index = 0
+        self._stage_state = _StateDict()
+        self._stage_output = inputs
+
+    def next(self):
+        assert not self.done
+
+        stage = self._stages[self._index]
+
+        with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
+            for op in stage:
+                with _annotate_region(debug_name=op.debug_name):
+                    self._stage_output = op.fn(
+                        state=self._stage_state,
+                        **(
+                            self._stage_output if self._stage_output is not None else {}
+                        ),
+                    )
+
+        self._index += 1
+
+    @property
+    def output(self):
+        assert self.done
+        return self._stage_output
+
+    @property
+    def done(self):
+        return self._index >= self.num_stages
+
+    @property
+    def num_stages(self):
+        return len(self._stages)
+
+
+@contextmanager
+def _annotate_region(debug_name):
+    if _ENABLE_PROFILE:
+        with torch.autograd.profiler.record_function(debug_name):
+            with nvtx.annotate(debug_name):
+                yield
+    else:
+        yield
+
+
+class _StateDict:
+    def __init__(self):
+        self._data = {}
+
+    def __setattr__(self, key, value):
+        if key == "_data":
+            super().__setattr__(key, value)
+            return
+        assert (
+            key not in self._data
+        ), f"`{key}` already exists, are you sure you want to override it?"
+        self._data[key] = value
+
+    def __getattr__(self, item):
+        return self._data[item]
+
+    def __delattr__(self, item):
+        del self._data[item]
+
+    def pop(self, item):
+        return self._data.pop(item)
+
+    def update(self, values: Dict[str, Any]):
+        for k, v in values.items():
+            setattr(self, k, v)
+
+    def clear(self, expect_keys: Sequence[str]):
+        if set(self._data.keys()) != set(expect_keys):
+            raise Exception(
+                "Unexpected keys when clearing. This may indicate that memory was "
+                "not released early enough but left until here. "
+                f"{list(self._data.keys())=} {expect_keys=}"
+            )
+
+        self._data.clear()
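`_StateDict` gives attribute-style access with two guard rails: a key may not be silently overwritten, and `clear` checks exactly which keys remain. For instance (illustrative):

    s = _StateDict()
    s.hidden = 1
    # s.hidden = 2 would raise AssertionError: `hidden` already exists
    assert s.pop("hidden") == 1  # removes the key, dropping the reference
    s.extra = 3
    s.clear(expect_keys={"extra"})  # passes and empties the state
    # clear(expect_keys={"other"}) would instead raise on the leftover key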
+
+def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
+    operation_chunks = list(
+        _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
+    )
+    assert all(len(chunk) > 0 for chunk in operation_chunks)
+    return operation_chunks
+
+
+def _chunk_by_separator(
+    items: List[Any], is_separator: Callable[[Any], bool]
+) -> Generator[List[Any], None, None]:
+    pending_items = []
+    for item in items:
+        if is_separator(item):
+            yield pending_items
+            pending_items = []
+        else:
+            pending_items.append(item)
+    if len(pending_items) > 0:
+        yield pending_items
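`_chunk_by_separator` is a generic splitter; for example (illustrative), `list(_chunk_by_separator([1, 2, "|", 3], lambda x: x == "|"))` returns `[[1, 2], [3]]`. Note that a leading separator would produce an empty chunk, which the assertion in `_convert_operations_to_stages` rejects.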
+
+def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
+    return [_decorate_operation(op, debug_name_prefix) for op in operations]
+
+
+def _decorate_operation(operation: Operation, debug_name_prefix: str):
+    if isinstance(operation, YieldOperation):
+        return operation
+    return ExecutionOperation(
+        debug_name=debug_name_prefix
+        + getattr(operation, "__name__", "unknown").replace("op_", ""),
+        fn=operation,
+    )
python/sglang/srt/operations_strategy.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import torch
+
+
+def compute_layer_operations(
+    layer: torch.nn.Module,
+):
+    if not layer.is_layer_sparse:
+        return [
+            layer.op_comm_prepare_attn,
+            layer.op_attn,
+            layer.op_comm_prepare_mlp,
+            layer.op_mlp,
+            layer.op_comm_postprocess_layer,
+        ]
+
+    # Will add TBO operation orders here
+    return [
+        layer.op_comm_prepare_attn,
+        layer.op_attn,
+        layer.op_comm_prepare_mlp,
+        layer.mlp.op_gate,
+        layer.mlp.op_shared_experts,
+        layer.mlp.op_select_experts,
+        layer.mlp.op_dispatch_a,
+        layer.mlp.op_dispatch_b,
+        layer.mlp.op_experts,
+        layer.mlp.op_combine_a,
+        layer.mlp.op_combine_b,
+        layer.mlp.op_output,
+        layer.op_comm_postprocess_layer,
+    ]
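The `# Will add TBO operation orders here` comment is the payoff of this granularity: once a layer is a flat list of atomic ops, stage boundaries can be inserted without touching model code. A hypothetical split (not in this commit) that yields after dispatch, so another micro-batch's compute could overlap the communication:

    from sglang.srt.operations import YieldOperation

    ops = compute_layer_operations(layer)  # `layer` is a placeholder
    cut = ops.index(layer.mlp.op_dispatch_a) + 1
    ops_with_boundary = ops[:cut] + [YieldOperation()] + ops[cut:]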