Support offloading in fp8 (#9948)

commit fa46e2bd40
parent b047b553c2
Author: fzyzcjy
Date: 2025-09-14 16:14:28 +08:00
Committed by: GitHub
4 changed files with 95 additions and 17 deletions

@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
 import torch
+from sglang.srt import offloader
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
@@ -417,10 +418,14 @@ def block_quant_dequant(
 def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
     assert isinstance(weight, torch.nn.Parameter)
     assert isinstance(weight_scale_inv, torch.nn.Parameter)
-    weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
-        weight, weight_scale_inv, weight_block_size
+    new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
+        weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
     )
+    offloader.update_param(weight, new_weight)
+    weight_scale_inv.data = new_weight_scale_inv


 def _requant_weight_ue8m0(
     weight: torch.Tensor,
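
In effect, the change stops rebinding `weight.data` directly and instead routes the requantized tensor through `offloader.update_param`, so a parameter whose backing storage has been offloaded (e.g. to CPU) is updated in its managed location rather than silently pulled back; moving `weight` to `weight_scale_inv.device` first keeps the requantization math on the device where the scales live. Below is a minimal, self-contained sketch of that pattern. The `Offloader` class, its `offload` method, and the copy-vs-rebind behavior inside `update_param` are assumptions for illustration, not the actual sglang implementation; `_requant_weight_ue8m0` is replaced by a trivial rescale stand-in.

import torch


class Offloader:
    """Hypothetical offloader sketch: remembers which parameters had their
    storage moved off the compute device (here simply: moved to CPU)."""

    def __init__(self) -> None:
        self._offloaded: set[int] = set()

    def offload(self, param: torch.nn.Parameter) -> None:
        # Move the backing storage to CPU and track the parameter.
        param.data = param.data.to("cpu")
        self._offloaded.add(id(param))

    def update_param(self, param: torch.nn.Parameter, new_value: torch.Tensor) -> None:
        # Copy into the existing managed storage instead of rebinding
        # param.data, so an offloaded parameter stays in its offloaded home.
        if id(param) in self._offloaded:
            param.data.copy_(new_value.to(param.data.device))
        else:
            param.data = new_value


def requant_inplace(off: Offloader,
                    weight: torch.nn.Parameter,
                    scale_inv: torch.nn.Parameter) -> None:
    # Compute on the device holding the scales, mirroring the commit's
    # weight.to(weight_scale_inv.device).
    w = weight.to(scale_inv.device)
    # Stand-in for _requant_weight_ue8m0: rescale, then round-trip the dtype.
    new_w = (w.float() * scale_inv.float()).to(w.dtype)
    off.update_param(weight, new_w)
    # The scale tensor is small and assumed never offloaded, so a plain
    # .data assignment (as in the commit) is fine for it.
    scale_inv.data = scale_inv.data.clone()


# Usage: an offloaded weight is requantized without leaving its CPU home.
off = Offloader()
weight = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float16), requires_grad=False)
scale_inv = torch.nn.Parameter(torch.ones(4, 4), requires_grad=False)
off.offload(weight)
requant_inplace(off, weight, scale_inv)
assert weight.data.device.type == "cpu"  # the update landed in the offloaded storage

The key design point the diff encodes is that `param.data = new_tensor` rebinds the parameter to whatever device `new_tensor` is on, which would defeat offloading; funneling updates through the offloader lets it decide whether to copy in place or rebind.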