diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 4303fcd4e..8184a9305 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe import (
@@ -31,7 +33,15 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.offloader import get_offloader
+from sglang.srt.utils import (
+    ceil_div,
+    dispose_tensor,
+    get_bool_env_var,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -535,6 +545,24 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
+        # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
+        w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        w2_weight_fp8 = (
+            self.w2_weight,
+            (
+                self.w2_weight_scale_inv
+                if self.use_block_quant
+                else self.w2_weight_scale
+            ),
+        )
+
         hidden_states_fp8_shape = hidden_states_fp8.shape
         hidden_states_fp8_device = hidden_states_fp8.device
         hidden_states_fp8_dtype = hidden_states_fp8.dtype
@@ -565,12 +593,17 @@
         )
         output_index = torch.empty_like(topk_idx)
 
-        num_recv_tokens_per_expert_gpu = torch.tensor(
-            num_recv_tokens_per_expert,
-            dtype=torch.int32,
-            pin_memory=True,
-            device="cpu",
-        ).cuda(non_blocking=True)
+        if get_offloader().forbid_copy_engine_usage:
+            num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+                num_recv_tokens_per_expert
+            )
+        else:
+            num_recv_tokens_per_expert_gpu = torch.tensor(
+                num_recv_tokens_per_expert,
+                dtype=torch.int32,
+                pin_memory=True,
+                device="cpu",
+            ).cuda(non_blocking=True)
         expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
 
         ep_scatter(
@@ -595,7 +628,7 @@
         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
             input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
-            input_tensor, self.w13_weight_fp8, gateup_output, m_indices
+            input_tensor, w13_weight_fp8, gateup_output, m_indices
         )
         del input_tensor
         down_input = torch.empty(
@@ -625,7 +658,7 @@
             down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
-            self.w2_weight_fp8,
+            w2_weight_fp8,
             down_output,
             m_indices,
         )
@@ -885,3 +918,12 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
     if get_moe_expert_parallel_world_size() > 1:
         return EPMoE
     return FusedMoE
+
+
+def copy_list_to_gpu_no_ce(arr: List[int]):
+    from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+    return tensor_gpu
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index 998423b86..b09b80907 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 
+from sglang.srt import offloader
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
@@ -417,10 +418,14 @@ def block_quant_dequant(
 def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
     assert isinstance(weight, torch.nn.Parameter)
     assert isinstance(weight_scale_inv, torch.nn.Parameter)
-    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
-        weight, weight_scale_inv, weight_block_size
+
+    new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
+        weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
     )
+    offloader.update_param(weight, new_weight)
+    weight_scale_inv.data = new_weight_scale_inv
+
 
 
 def _requant_weight_ue8m0(
     weight: torch.Tensor,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 7c535a6ef..478fe9ed2 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -2244,8 +2244,15 @@ class DeepseekV2Model(nn.Module):
                 [
                     "w13_weight",
                     "w2_weight",
-                    "w13_blockscale_swizzled",
-                    "w2_blockscale_swizzled",
+                    # only for nvfp4
+                    *(
+                        [
+                            "w13_blockscale_swizzled",
+                            "w2_blockscale_swizzled",
+                        ]
+                        if hasattr(module, "w13_blockscale_swizzled")
+                        else []
+                    ),
                 ]
                 if isinstance(module, FusedMoE)
                 else []
diff --git a/python/sglang/srt/offloader.py b/python/sglang/srt/offloader.py
index aea7d7f23..0adddf5a6 100644
--- a/python/sglang/srt/offloader.py
+++ b/python/sglang/srt/offloader.py
@@ -38,6 +38,10 @@ class BaseOffloader(ABC):
     def post_init(self):
         pass
 
+    @property
+    def forbid_copy_engine_usage(self):
+        return False
+
 
 class NoopOffloader(BaseOffloader):
     pass
@@ -233,6 +237,10 @@ class OffloaderV2(BaseOffloader):
         for i in range(self.prefetch_step):
             self.offloaders[i].start_onload()
 
+    @property
+    def forbid_copy_engine_usage(self):
+        return self.mode == "cpu"
+
 
 def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
     def _on_forward_end():
@@ -398,14 +406,30 @@ class _ShmCpuParamOffloader(_BaseParamOffloader):
         return self.shm_cpu_data.to("cuda", non_blocking=True)
 
 
+def update_param(param, new_tensor):
+    """Update parameter while keeping properties needed by Offloader (e.g. pinned host memory)."""
+
+    if param.device == new_tensor.device:
+        param.data = new_tensor
+    else:
+        assert param.device == torch.device(
+            "cpu"
+        ), f"{param.device=} {new_tensor.device=}"
+        param.data = _create_cpu_data(new_tensor, pin_memory=True)
+
+
 def _move_param_to_cpu(param, pin_memory: bool):
+    param.data = _create_cpu_data(param.data, pin_memory=pin_memory)
+
+
+def _create_cpu_data(data, pin_memory: bool):
     cpu_data = _empty_strided_like(
-        param.data,
+        data,
         device="cpu",
         pin_memory=pin_memory,
     )
-    cpu_data.copy_(param.data)
-    param.data = cpu_data
+    cpu_data.copy_(data)
+    return cpu_data
 
 
 def _move_param_to_meta(module, param_name):
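For context on the new `copy_list_to_gpu_no_ce` path: when the offloader runs in `cpu` mode, the copy engine is presumably kept busy streaming offloaded weights between host and device, so `forbid_copy_engine_usage` reroutes this small metadata copy through a compute kernel (`sgl_kernel.elementwise.copy_to_gpu_no_ce`) instead of a pinned-memory DMA transfer. A minimal sanity check of the two paths, assuming a CUDA device and an installed `sgl_kernel`:

```python
import torch

from sglang.srt.layers.moe.ep_moe.layer import copy_list_to_gpu_no_ce

counts = [3, 0, 7, 12]  # e.g. tokens received per local expert

# Path 1: copy via a compute kernel, bypassing the DMA copy engine.
via_kernel = copy_list_to_gpu_no_ce(counts)

# Path 2: the pre-existing pinned-memory copy-engine path.
via_copy_engine = torch.tensor(
    counts, dtype=torch.int32, pin_memory=True, device="cpu"
).cuda(non_blocking=True)

torch.cuda.synchronize()
assert torch.equal(via_kernel, via_copy_engine)
```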
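Similarly, `requant_weight_ue8m0_inplace` now routes the new weight through `offloader.update_param` instead of assigning `weight.data` directly: for a CPU-offloaded parameter, a plain assignment of the freshly computed tensor would either move the parameter off the host or drop the pinned host storage that the offloader relies on for non-blocking onloads. A minimal sketch of that invariant, assuming a CUDA device:

```python
import torch

from sglang.srt.offloader import update_param

# A parameter living in pinned host memory, as the cpu-mode offloader keeps it.
param = torch.nn.Parameter(
    torch.zeros(4, 4, pin_memory=True), requires_grad=False
)
new_value = torch.randn(4, 4, device="cuda")  # e.g. a requantized weight

# update_param copies the CUDA tensor into freshly allocated pinned CPU
# storage rather than rebinding param.data to the CUDA tensor.
update_param(param, new_value)
assert param.device.type == "cpu" and param.data.is_pinned()
```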
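`_requant_weight_ue8m0` itself is not part of this diff. As background, here is a heavily simplified per-tensor sketch of the UE8M0 idea; the real helper works block-wise per `weight_block_size`, and the exact rounding direction is an assumption here. UE8M0 scales carry only an 8-bit exponent, so each scale must become a power of two, and the fp8 weight is requantized to compensate:

```python
import torch

def requant_ue8m0_sketch(weight_fp8: torch.Tensor, scale: torch.Tensor):
    """Per-tensor sketch only (assumes a single positive scalar scale)."""
    w = weight_fp8.to(torch.float32) * scale                # dequantize
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))   # round up to 2**k
    new_weight = (w / new_scale).to(torch.float8_e4m3fn)    # requantize
    return new_weight, new_scale
```

This is also why the caller now does `weight.to(weight_scale_inv.device)` first: with CPU offload the weight may start on the host while the scale (and the requantization math) live on the GPU.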