Support offloading in fp8 (#9948)
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||
from sglang.srt.layers.moe import (
|
||||
@@ -31,7 +33,15 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
||||
from sglang.srt.offloader import get_offloader
|
||||
from sglang.srt.utils import (
|
||||
ceil_div,
|
||||
dispose_tensor,
|
||||
get_bool_env_var,
|
||||
is_cuda,
|
||||
is_hip,
|
||||
is_npu,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.token_dispatcher import (
|
||||
@@ -535,6 +545,24 @@ class DeepEPMoE(EPMoE):
|
||||
N = self.w13_weight.size(1)
|
||||
scale_block_size = 128
|
||||
|
||||
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
|
||||
w13_weight_fp8 = (
|
||||
self.w13_weight,
|
||||
(
|
||||
self.w13_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
else self.w13_weight_scale
|
||||
),
|
||||
)
|
||||
w2_weight_fp8 = (
|
||||
self.w2_weight,
|
||||
(
|
||||
self.w2_weight_scale_inv
|
||||
if self.use_block_quant
|
||||
else self.w2_weight_scale
|
||||
),
|
||||
)
|
||||
|
||||
hidden_states_fp8_shape = hidden_states_fp8.shape
|
||||
hidden_states_fp8_device = hidden_states_fp8.device
|
||||
hidden_states_fp8_dtype = hidden_states_fp8.dtype
|
||||
@@ -565,12 +593,17 @@ class DeepEPMoE(EPMoE):
|
||||
)
|
||||
output_index = torch.empty_like(topk_idx)
|
||||
|
||||
num_recv_tokens_per_expert_gpu = torch.tensor(
|
||||
num_recv_tokens_per_expert,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True,
|
||||
device="cpu",
|
||||
).cuda(non_blocking=True)
|
||||
if get_offloader().forbid_copy_engine_usage:
|
||||
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
|
||||
num_recv_tokens_per_expert
|
||||
)
|
||||
else:
|
||||
num_recv_tokens_per_expert_gpu = torch.tensor(
|
||||
num_recv_tokens_per_expert,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True,
|
||||
device="cpu",
|
||||
).cuda(non_blocking=True)
|
||||
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
|
||||
|
||||
ep_scatter(
|
||||
@@ -595,7 +628,7 @@ class DeepEPMoE(EPMoE):
|
||||
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||
input_tensor[1] = tma_align_input_scale(input_tensor[1])
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
||||
input_tensor, self.w13_weight_fp8, gateup_output, m_indices
|
||||
input_tensor, w13_weight_fp8, gateup_output, m_indices
|
||||
)
|
||||
del input_tensor
|
||||
down_input = torch.empty(
|
||||
@@ -625,7 +658,7 @@ class DeepEPMoE(EPMoE):
|
||||
down_input_scale = tma_align_input_scale(down_input_scale)
|
||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
||||
(down_input_fp8, down_input_scale),
|
||||
self.w2_weight_fp8,
|
||||
w2_weight_fp8,
|
||||
down_output,
|
||||
m_indices,
|
||||
)
|
||||
@@ -885,3 +918,12 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
||||
if get_moe_expert_parallel_world_size() > 1:
|
||||
return EPMoE
|
||||
return FusedMoE
|
||||
|
||||
|
||||
def copy_list_to_gpu_no_ce(arr: List[int]):
|
||||
from sgl_kernel.elementwise import copy_to_gpu_no_ce
|
||||
|
||||
tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
|
||||
tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
|
||||
copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
|
||||
return tensor_gpu
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt import offloader
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
|
||||
@@ -417,10 +418,14 @@ def block_quant_dequant(
|
||||
def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
|
||||
assert isinstance(weight, torch.nn.Parameter)
|
||||
assert isinstance(weight_scale_inv, torch.nn.Parameter)
|
||||
weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
|
||||
weight, weight_scale_inv, weight_block_size
|
||||
|
||||
new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
|
||||
weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
offloader.update_param(weight, new_weight)
|
||||
weight_scale_inv.data = new_weight_scale_inv
|
||||
|
||||
|
||||
def _requant_weight_ue8m0(
|
||||
weight: torch.Tensor,
|
||||
|
||||
@@ -2244,8 +2244,15 @@ class DeepseekV2Model(nn.Module):
|
||||
[
|
||||
"w13_weight",
|
||||
"w2_weight",
|
||||
"w13_blockscale_swizzled",
|
||||
"w2_blockscale_swizzled",
|
||||
# only for nvfp4
|
||||
*(
|
||||
[
|
||||
"w13_blockscale_swizzled",
|
||||
"w2_blockscale_swizzled",
|
||||
]
|
||||
if hasattr(module, "w13_blockscale_swizzled")
|
||||
else []
|
||||
),
|
||||
]
|
||||
if isinstance(module, FusedMoE)
|
||||
else []
|
||||
|
||||
@@ -38,6 +38,10 @@ class BaseOffloader(ABC):
|
||||
def post_init(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def forbid_copy_engine_usage(self):
|
||||
return False
|
||||
|
||||
|
||||
class NoopOffloader(BaseOffloader):
|
||||
pass
|
||||
@@ -233,6 +237,10 @@ class OffloaderV2(BaseOffloader):
|
||||
for i in range(self.prefetch_step):
|
||||
self.offloaders[i].start_onload()
|
||||
|
||||
@property
|
||||
def forbid_copy_engine_usage(self):
|
||||
return self.mode == "cpu"
|
||||
|
||||
|
||||
def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
|
||||
def _on_forward_end():
|
||||
@@ -398,14 +406,30 @@ class _ShmCpuParamOffloader(_BaseParamOffloader):
|
||||
return self.shm_cpu_data.to("cuda", non_blocking=True)
|
||||
|
||||
|
||||
def update_param(param, new_tensor):
|
||||
"""Update parameter while keeping properties needed by Offloader (e.g. pinned host memory)."""
|
||||
|
||||
if param.device == new_tensor.device:
|
||||
param.data = new_tensor
|
||||
else:
|
||||
assert param.device == torch.device(
|
||||
"cpu"
|
||||
), f"{param.device=} {new_tensor.device=}"
|
||||
param.data = _create_cpu_data(new_tensor, pin_memory=True)
|
||||
|
||||
|
||||
def _move_param_to_cpu(param, pin_memory: bool):
|
||||
param.data = _create_cpu_data(param.data, pin_memory=pin_memory)
|
||||
|
||||
|
||||
def _create_cpu_data(data, pin_memory: bool):
|
||||
cpu_data = _empty_strided_like(
|
||||
param.data,
|
||||
data,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
cpu_data.copy_(param.data)
|
||||
param.data = cpu_data
|
||||
cpu_data.copy_(data)
|
||||
return cpu_data
|
||||
|
||||
|
||||
def _move_param_to_meta(module, param_name):
|
||||
|
||||
Reference in New Issue
Block a user