Support offloading in fp8 (#9948)
This commit is contained in:
@@ -2,6 +2,7 @@ from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt import offloader
|
||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
|
||||
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
|
||||
@@ -417,10 +418,14 @@ def block_quant_dequant(
|
||||
def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
|
||||
assert isinstance(weight, torch.nn.Parameter)
|
||||
assert isinstance(weight_scale_inv, torch.nn.Parameter)
|
||||
weight.data, weight_scale_inv.data = _requant_weight_ue8m0(
|
||||
weight, weight_scale_inv, weight_block_size
|
||||
|
||||
new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
|
||||
weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
offloader.update_param(weight, new_weight)
|
||||
weight_scale_inv.data = new_weight_scale_inv
|
||||
|
||||
|
||||
def _requant_weight_ue8m0(
|
||||
weight: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user