Clean up imports (#5467)

This commit is contained in:
Lianmin Zheng
2025-04-16 15:26:49 -07:00
committed by GitHub
parent d7bc19a46a
commit 177320a582
51 changed files with 376 additions and 573 deletions

View File

@@ -1,18 +1,17 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType
from typing import List, Mapping, Optional, Tuple, Union
from typing import List, Mapping, Tuple, Union
import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
from vllm import _custom_ops as vllm_ops
if not _is_cuda:
from vllm._custom_ops import scaled_fp8_quant
def is_fp8_fnuz() -> bool:
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
if _is_cuda:
weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
else:
weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
weight_dq, max_w_scale
)
weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
start = end
return max_w_scale, weight