Support offloading in fp8 (#9948)
This commit is contained in:
@@ -1,9 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
|
||||||
from sglang.srt.layers.moe import (
|
from sglang.srt.layers.moe import (
|
||||||
@@ -31,7 +33,15 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
|
from sglang.srt.offloader import get_offloader
|
||||||
|
from sglang.srt.utils import (
|
||||||
|
ceil_div,
|
||||||
|
dispose_tensor,
|
||||||
|
get_bool_env_var,
|
||||||
|
is_cuda,
|
||||||
|
is_hip,
|
||||||
|
is_npu,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.token_dispatcher import (
|
from sglang.srt.layers.moe.token_dispatcher import (
|
||||||
@@ -535,6 +545,24 @@ class DeepEPMoE(EPMoE):
|
|||||||
N = self.w13_weight.size(1)
|
N = self.w13_weight.size(1)
|
||||||
scale_block_size = 128
|
scale_block_size = 128
|
||||||
|
|
||||||
|
# TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
|
||||||
|
w13_weight_fp8 = (
|
||||||
|
self.w13_weight,
|
||||||
|
(
|
||||||
|
self.w13_weight_scale_inv
|
||||||
|
if self.use_block_quant
|
||||||
|
else self.w13_weight_scale
|
||||||
|
),
|
||||||
|
)
|
||||||
|
w2_weight_fp8 = (
|
||||||
|
self.w2_weight,
|
||||||
|
(
|
||||||
|
self.w2_weight_scale_inv
|
||||||
|
if self.use_block_quant
|
||||||
|
else self.w2_weight_scale
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states_fp8_shape = hidden_states_fp8.shape
|
hidden_states_fp8_shape = hidden_states_fp8.shape
|
||||||
hidden_states_fp8_device = hidden_states_fp8.device
|
hidden_states_fp8_device = hidden_states_fp8.device
|
||||||
hidden_states_fp8_dtype = hidden_states_fp8.dtype
|
hidden_states_fp8_dtype = hidden_states_fp8.dtype
|
||||||
@@ -565,12 +593,17 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
output_index = torch.empty_like(topk_idx)
|
output_index = torch.empty_like(topk_idx)
|
||||||
|
|
||||||
num_recv_tokens_per_expert_gpu = torch.tensor(
|
if get_offloader().forbid_copy_engine_usage:
|
||||||
num_recv_tokens_per_expert,
|
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
|
||||||
dtype=torch.int32,
|
num_recv_tokens_per_expert
|
||||||
pin_memory=True,
|
)
|
||||||
device="cpu",
|
else:
|
||||||
).cuda(non_blocking=True)
|
num_recv_tokens_per_expert_gpu = torch.tensor(
|
||||||
|
num_recv_tokens_per_expert,
|
||||||
|
dtype=torch.int32,
|
||||||
|
pin_memory=True,
|
||||||
|
device="cpu",
|
||||||
|
).cuda(non_blocking=True)
|
||||||
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
|
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
|
||||||
|
|
||||||
ep_scatter(
|
ep_scatter(
|
||||||
@@ -595,7 +628,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
|
||||||
input_tensor[1] = tma_align_input_scale(input_tensor[1])
|
input_tensor[1] = tma_align_input_scale(input_tensor[1])
|
||||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
||||||
input_tensor, self.w13_weight_fp8, gateup_output, m_indices
|
input_tensor, w13_weight_fp8, gateup_output, m_indices
|
||||||
)
|
)
|
||||||
del input_tensor
|
del input_tensor
|
||||||
down_input = torch.empty(
|
down_input = torch.empty(
|
||||||
@@ -625,7 +658,7 @@ class DeepEPMoE(EPMoE):
|
|||||||
down_input_scale = tma_align_input_scale(down_input_scale)
|
down_input_scale = tma_align_input_scale(down_input_scale)
|
||||||
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
|
||||||
(down_input_fp8, down_input_scale),
|
(down_input_fp8, down_input_scale),
|
||||||
self.w2_weight_fp8,
|
w2_weight_fp8,
|
||||||
down_output,
|
down_output,
|
||||||
m_indices,
|
m_indices,
|
||||||
)
|
)
|
||||||
@@ -885,3 +918,12 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
|
|||||||
if get_moe_expert_parallel_world_size() > 1:
|
if get_moe_expert_parallel_world_size() > 1:
|
||||||
return EPMoE
|
return EPMoE
|
||||||
return FusedMoE
|
return FusedMoE
|
||||||
|
|
||||||
|
|
||||||
|
def copy_list_to_gpu_no_ce(arr: List[int]):
|
||||||
|
from sgl_kernel.elementwise import copy_to_gpu_no_ce
|
||||||
|
|
||||||
|
tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
|
||||||
|
tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
|
||||||
|
copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
|
||||||
|
return tensor_gpu
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt import offloader
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||||
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
|
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
|
||||||
@@ -417,10 +418,14 @@ def block_quant_dequant(
|
|||||||
def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
|
def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
|
||||||
assert isinstance(weight, torch.nn.Parameter)
|
assert isinstance(weight, torch.nn.Parameter)
|
||||||
assert isinstance(weight_scale_inv, torch.nn.Parameter)
|
assert isinstance(weight_scale_inv, torch.nn.Parameter)
|
||||||
weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
|
|
||||||
weight, weight_scale_inv, weight_block_size
|
new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
|
||||||
|
weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
offloader.update_param(weight, new_weight)
|
||||||
|
weight_scale_inv.data = new_weight_scale_inv
|
||||||
|
|
||||||
|
|
||||||
def _requant_weight_ue8m0(
|
def _requant_weight_ue8m0(
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
|
|||||||
@@ -2244,8 +2244,15 @@ class DeepseekV2Model(nn.Module):
|
|||||||
[
|
[
|
||||||
"w13_weight",
|
"w13_weight",
|
||||||
"w2_weight",
|
"w2_weight",
|
||||||
"w13_blockscale_swizzled",
|
# only for nvfp4
|
||||||
"w2_blockscale_swizzled",
|
*(
|
||||||
|
[
|
||||||
|
"w13_blockscale_swizzled",
|
||||||
|
"w2_blockscale_swizzled",
|
||||||
|
]
|
||||||
|
if hasattr(module, "w13_blockscale_swizzled")
|
||||||
|
else []
|
||||||
|
),
|
||||||
]
|
]
|
||||||
if isinstance(module, FusedMoE)
|
if isinstance(module, FusedMoE)
|
||||||
else []
|
else []
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ class BaseOffloader(ABC):
|
|||||||
def post_init(self):
|
def post_init(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def forbid_copy_engine_usage(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class NoopOffloader(BaseOffloader):
|
class NoopOffloader(BaseOffloader):
|
||||||
pass
|
pass
|
||||||
@@ -233,6 +237,10 @@ class OffloaderV2(BaseOffloader):
|
|||||||
for i in range(self.prefetch_step):
|
for i in range(self.prefetch_step):
|
||||||
self.offloaders[i].start_onload()
|
self.offloaders[i].start_onload()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def forbid_copy_engine_usage(self):
|
||||||
|
return self.mode == "cpu"
|
||||||
|
|
||||||
|
|
||||||
def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
|
def _hook_module_forward_for_offloader(index, module, offloaders, prefetch_step):
|
||||||
def _on_forward_end():
|
def _on_forward_end():
|
||||||
@@ -398,14 +406,30 @@ class _ShmCpuParamOffloader(_BaseParamOffloader):
|
|||||||
return self.shm_cpu_data.to("cuda", non_blocking=True)
|
return self.shm_cpu_data.to("cuda", non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
|
def update_param(param, new_tensor):
|
||||||
|
"""Update parameter while keeping properties needed by Offloader (e.g. pinned host memory)."""
|
||||||
|
|
||||||
|
if param.device == new_tensor.device:
|
||||||
|
param.data = new_tensor
|
||||||
|
else:
|
||||||
|
assert param.device == torch.device(
|
||||||
|
"cpu"
|
||||||
|
), f"{param.device=} {new_tensor.device=}"
|
||||||
|
param.data = _create_cpu_data(new_tensor, pin_memory=True)
|
||||||
|
|
||||||
|
|
||||||
def _move_param_to_cpu(param, pin_memory: bool):
|
def _move_param_to_cpu(param, pin_memory: bool):
|
||||||
|
param.data = _create_cpu_data(param.data, pin_memory=pin_memory)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_cpu_data(data, pin_memory: bool):
|
||||||
cpu_data = _empty_strided_like(
|
cpu_data = _empty_strided_like(
|
||||||
param.data,
|
data,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
cpu_data.copy_(param.data)
|
cpu_data.copy_(data)
|
||||||
param.data = cpu_data
|
return cpu_data
|
||||||
|
|
||||||
|
|
||||||
def _move_param_to_meta(module, param_name):
|
def _move_param_to_meta(module, param_name):
|
||||||
|
|||||||
Reference in New Issue
Block a user