diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2141a500f..9b44186be 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -51,7 +51,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -83,6 +83,8 @@ from sglang.srt.managers.expert_location import ModelConfigForExpertLocation from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.operations import execute_operations +from sglang.srt.operations_strategy import compute_layer_operations from sglang.srt.utils import ( BumpAllocator, DeepEPMode, @@ -304,99 +306,123 @@ class DeepseekV2MoE(nn.Module): def _enable_deepep_moe(self): return global_server_args_dict["enable_deepep_moe"] - def forward( - self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None - ) -> torch.Tensor: - forward_mode = forward_batch.forward_mode + def op_gate(self, state): if (not self._enable_deepep_moe) or is_non_idle_and_non_empty( - forward_mode, hidden_states + state.forward_batch.forward_mode, state.hidden_states_mlp_input ): # router_logits: (num_tokens, n_experts) - router_logits = self.gate(hidden_states) + state.router_logits = self.gate(state.hidden_states_mlp_input) else: - router_logits = None + state.router_logits = None + def op_shared_experts(self, state): if (self.n_share_experts_fusion == 0) and ( 
(not self._enable_deepep_moe) - or is_non_idle_and_non_empty(forward_mode, hidden_states) + or is_non_idle_and_non_empty( + state.forward_batch.forward_mode, state.hidden_states_mlp_input + ) ): - shared_output = self.shared_experts(hidden_states) + state.shared_output = self.shared_experts(state.hidden_states_mlp_input) else: - shared_output = None + state.shared_output = None - if self._enable_deepep_moe and (router_logits is not None): - topk_weights, topk_idx = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - correction_bias=self.correction_bias, - routed_scaling_factor=self.routed_scaling_factor, - num_token_non_padded=forward_batch.num_token_non_padded, - ) - else: - topk_idx = torch.full( - (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device - ) - topk_weights = torch.empty( - (0, self.top_k), dtype=torch.float32, device=hidden_states.device - ) - - if self._enable_deepep_moe and (self.ep_size > 1): - # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value - ( - hidden_states, - topk_idx, - topk_weights, - reorder_topk_ids, - num_recv_tokens_per_expert, - seg_indptr, - masked_m, - expected_m, - ) = self.deepep_dispatcher.dispatch( - hidden_states, - topk_idx, - topk_weights, - forward_mode=forward_mode, - ) + def op_select_experts(self, state): + router_logits = state.router_logits + hidden_states = state.hidden_states_mlp_input if self._enable_deepep_moe: - final_hidden_states = self.experts( - hidden_states=hidden_states, - topk_idx=topk_idx, - topk_weights=topk_weights, - reorder_topk_ids=reorder_topk_ids, - seg_indptr=seg_indptr, - masked_m=masked_m, - expected_m=expected_m, - num_recv_tokens_per_expert=num_recv_tokens_per_expert, - forward_mode=forward_mode, - ) - else: - final_hidden_states = self.experts( - hidden_states=hidden_states, 
router_logits=router_logits + if router_logits is not None: + state.topk_weights_local, state.topk_idx_local = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + correction_bias=self.correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + ) + else: + state.topk_idx_local = torch.full( + (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device + ) + state.topk_weights_local = torch.empty( + (0, self.top_k), dtype=torch.float32, device=hidden_states.device + ) + + def op_dispatch_a(self, state): + if self._enable_deepep_moe and (self.ep_size > 1): + # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value + self.deepep_dispatcher.dispatch_a( + hidden_states=state.pop("hidden_states_mlp_input"), + topk_idx=state.pop("topk_idx_local"), + topk_weights=state.pop("topk_weights_local"), + forward_mode=state.forward_batch.forward_mode, ) + def op_dispatch_b(self, state): if self._enable_deepep_moe and (self.ep_size > 1): - final_hidden_states = self.deepep_dispatcher.combine( - final_hidden_states, - topk_idx, - topk_weights, - forward_mode, + ( + state.hidden_states_experts_input, + state.topk_idx_dispatched, + state.topk_weights_dispatched, + state.reorder_topk_ids, + state.num_recv_tokens_per_expert, + state.seg_indptr, + state.masked_m, + state.expected_m, + ) = self.deepep_dispatcher.dispatch_b() + + def op_experts(self, state): + if self._enable_deepep_moe: + state.pop("router_logits") + state.hidden_states_experts_output = self.experts( + hidden_states=state.pop("hidden_states_experts_input"), + topk_idx=state.topk_idx_dispatched, + topk_weights=state.topk_weights_dispatched, + reorder_topk_ids=state.pop("reorder_topk_ids"), + seg_indptr=state.pop("seg_indptr"), + masked_m=state.pop("masked_m"), + expected_m=state.pop("expected_m"), + 
num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"), + forward_mode=state.forward_batch.forward_mode, ) + else: + state.hidden_states_experts_output = self.experts( + hidden_states=state.pop("hidden_states_mlp_input"), + router_logits=state.pop("router_logits"), + ) + + def op_combine_a(self, state): + if self._enable_deepep_moe and (self.ep_size > 1): + self.deepep_dispatcher.combine_a( + state.pop("hidden_states_experts_output"), + topk_idx=state.pop("topk_idx_dispatched"), + topk_weights=state.pop("topk_weights_dispatched"), + forward_mode=state.forward_batch.forward_mode, + ) + + def op_combine_b(self, state): + if self._enable_deepep_moe and (self.ep_size > 1): + state.hidden_states_after_combine = self.deepep_dispatcher.combine_b() + + def op_output(self, state): + final_hidden_states = ( + state.pop("hidden_states_after_combine") + if self._enable_deepep_moe + else state.pop("hidden_states_experts_output") + ) final_hidden_states *= self.routed_scaling_factor - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + if (s := state.pop("shared_output")) is not None: + final_hidden_states = final_hidden_states + s if (not self._enable_deepep_moe) and (self.tp_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - return final_hidden_states + state.hidden_states_mlp_output = final_hidden_states def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: @@ -1197,27 +1223,77 @@ class DeepseekV2DecoderLayer(nn.Module): residual: Optional[torch.Tensor], zero_allocator: BumpAllocator, ) -> torch.Tensor: - hidden_states, residual = self.layer_communicator.prepare_attn( - hidden_states, residual, forward_batch + return execute_operations( + inputs=dict( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + residual=residual, + zero_allocator=zero_allocator, + ), + operations=compute_layer_operations(self), ) - hidden_states = self.self_attn( - 
positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - zero_allocator=zero_allocator, + def op_comm_prepare_attn( + self, + state, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + zero_allocator: BumpAllocator, + ): + state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = ( + self.layer_communicator.prepare_attn( + hidden_states, residual, state.forward_batch + ) + ) + state.update( + dict( + forward_batch=forward_batch, + positions=positions, + zero_allocator=zero_allocator, + ) ) - hidden_states, residual = self.layer_communicator.prepare_mlp( - hidden_states, residual, forward_batch + def op_attn(self, state): + state.hidden_states_after_attn = self.self_attn( + positions=state.positions, + hidden_states=state.pop("hidden_states_after_comm_pre_attn"), + forward_batch=state.forward_batch, + zero_allocator=state.zero_allocator, ) - hidden_states = self.mlp(hidden_states, forward_batch) + def op_comm_prepare_mlp(self, state): + state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = ( + self.layer_communicator.prepare_mlp( + state.pop("hidden_states_after_attn"), + state.pop("residual_after_input_ln"), + state.forward_batch, + ) + ) + def op_mlp(self, state): + hidden_states = state.pop("hidden_states_mlp_input") + if not ( + enable_moe_dense_fully_dp() + and (not self.is_layer_sparse) + and hidden_states.shape[0] == 0 + ): + state.hidden_states_mlp_output = self.mlp( + hidden_states, state.forward_batch.forward_mode + ) + else: + state.hidden_states_mlp_output = hidden_states + + def op_comm_postprocess_layer(self, state): hidden_states, residual = self.layer_communicator.postprocess_layer( - hidden_states, residual, forward_batch + state.pop("hidden_states_mlp_output"), + state.pop("residual_after_comm_pre_mlp"), + state.forward_batch, ) + state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"}) return 
hidden_states, residual diff --git a/python/sglang/srt/operations.py b/python/sglang/srt/operations.py new file mode 100644 index 000000000..0ef9fb9c4 --- /dev/null +++ b/python/sglang/srt/operations.py @@ -0,0 +1,154 @@ +import os +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generator, List, Sequence, Union + +import torch + +_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0"))) + +if _ENABLE_PROFILE: + import nvtx + + +def execute_operations(inputs, operations): + stages = _convert_operations_to_stages(decorate_operations(operations)) + executor = _StageExecutor("primary", stages, inputs=inputs) + for _ in range(executor.num_stages): + executor.next() + assert executor.done + return executor.output + + +class YieldOperation: + pass + + +@dataclass +class ExecutionOperation: + debug_name: str + fn: Callable + + +Operation = Union[YieldOperation, ExecutionOperation, Callable] +Stage = List[ExecutionOperation] + + +class _StageExecutor: + def __init__(self, debug_name: str, stages: List[Stage], inputs): + self._debug_name = debug_name + self._stages = stages + self._index = 0 + self._stage_state = _StateDict() + self._stage_output = inputs + + def next(self): + assert not self.done + + stage = self._stages[self._index] + + with _annotate_region(debug_name=f"{self._debug_name}{self._index}"): + for op in stage: + with _annotate_region(debug_name=op.debug_name): + self._stage_output = op.fn( + state=self._stage_state, + **( + self._stage_output if self._stage_output is not None else {} + ), + ) + + self._index += 1 + + @property + def output(self): + assert self.done + return self._stage_output + + @property + def done(self): + return self._index >= self.num_stages + + @property + def num_stages(self): + return len(self._stages) + + +@contextmanager +def _annotate_region(debug_name): + if _ENABLE_PROFILE: + with torch.autograd.profiler.record_function(debug_name): + 
with nvtx.annotate(debug_name): + yield + else: + yield + + +class _StateDict: + def __init__(self): + self._data = {} + + def __setattr__(self, key, value): + if key == "_data": + super().__setattr__(key, value) + return + assert ( + key not in self._data + ), f"`{key}` already exists, are you sure you want to override it?" + self._data[key] = value + + def __getattr__(self, item): + return self._data[item] + + def __delattr__(self, item): + del self._data[item] + + def pop(self, item): + return self._data.pop(item) + + def update(self, values: Dict[str, Any]): + for k, v in values.items(): + setattr(self, k, v) + + def clear(self, expect_keys: Sequence[str]): + if set(self._data.keys()) != set(expect_keys): + raise Exception( + f"Unexpected keys when clearing. This may indicate you do not release memory early enough but leave it to here. {list(self._data.keys())=} {expect_keys=}" + ) + + self._data.clear() + + +def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]: + operation_chunks = list( + _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation)) + ) + assert all(len(chunk) > 0 for chunk in operation_chunks) + return operation_chunks + + +def _chunk_by_separator( + items: List[Any], is_separator: Callable[[Any], bool] +) -> Generator[List[Any], None, None]: + pending_items = [] + for item in items: + if is_separator(item): + yield pending_items + pending_items = [] + else: + pending_items.append(item) + if len(pending_items) > 0: + yield pending_items + + +def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""): + return [_decorate_operation(op, debug_name_prefix) for op in operations] + + +def _decorate_operation(operation: Operation, debug_name_prefix: str): + if isinstance(operation, YieldOperation): + return operation + return ExecutionOperation( + debug_name=debug_name_prefix + + getattr(operation, "__name__", "unknown").replace("op_", ""), + fn=operation, + ) diff --git 
a/python/sglang/srt/operations_strategy.py b/python/sglang/srt/operations_strategy.py new file mode 100644 index 000000000..c30a53ac9 --- /dev/null +++ b/python/sglang/srt/operations_strategy.py @@ -0,0 +1,31 @@ +import torch + + +def compute_layer_operations( + layer: torch.nn.Module, +): + if not layer.is_layer_sparse: + return [ + layer.op_comm_prepare_attn, + layer.op_attn, + layer.op_comm_prepare_mlp, + layer.op_mlp, + layer.op_comm_postprocess_layer, + ] + + # Will add TBO operation orders here + return [ + layer.op_comm_prepare_attn, + layer.op_attn, + layer.op_comm_prepare_mlp, + layer.mlp.op_gate, + layer.mlp.op_shared_experts, + layer.mlp.op_select_experts, + layer.mlp.op_dispatch_a, + layer.mlp.op_dispatch_b, + layer.mlp.op_experts, + layer.mlp.op_combine_a, + layer.mlp.op_combine_b, + layer.mlp.op_output, + layer.op_comm_postprocess_layer, + ]