Clean up imports (#5467)
@@ -1,18 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
 
 from types import MappingProxyType
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Mapping, Tuple, Union
 
 import torch
 
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import is_cuda
 
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
 
 def is_fp8_fnuz() -> bool:
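This hunk replaces the call-site branching between the sgl and vllm kernels with a single name imported from sglang.srt.layers.quantization.fp8_kernel. For context, a minimal sketch of the import-time dispatch pattern this relies on; this is illustrative only, with the two backend imports taken from the code shown in this diff rather than from fp8_kernel itself:

# Illustrative sketch of a backend-dispatching module: expose one
# scaled_fp8_quant name and pick the implementation at import time,
# so call sites no longer have to branch on _is_cuda themselves.
from sglang.srt.utils import is_cuda

if is_cuda():
    # CUDA path (the import the old code used at the call site).
    from sglang.srt.custom_op import scaled_fp8_quant
else:
    # Non-CUDA fallback, matching the new top-level import above.
    from vllm._custom_ops import scaled_fp8_quant

__all__ = ["scaled_fp8_quant"]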
@@ -116,12 +115,7 @@ def requantize_with_max_scale(
     for idx, logical_width in enumerate(logical_widths):
         end = start + logical_width
         weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
-        if _is_cuda:
-            weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
-        else:
-            weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
-                weight_dq, max_w_scale
-            )
+        weight[start:end, :], _ = scaled_fp8_quant(weight_dq, max_w_scale)
         start = end
 
     return max_w_scale, weight
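With the unified import, the requantization loop calls scaled_fp8_quant directly. As a rough, self-contained sketch of what this loop computes, here is a naive stand-in for the fused kernel (assumed semantics, not the library implementation; requires a PyTorch build that has torch.float8_e4m3fn):

# Each logical partition of a fused weight was quantized with its own scale;
# requantize_with_max_scale dequantizes each slice and requantizes it with the
# shared max scale so the whole fused weight uses a single scale.
import torch

def per_tensor_dequantize(q, scale):
    # FP8 values back to float32, multiplied by the per-partition scale.
    return q.to(torch.float32) * scale

def naive_scaled_fp8_quant(x, scale):
    # Naive stand-in for the fused kernel: divide by the scale, clamp to
    # the e4m3 representable range, cast down to FP8.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = torch.clamp(x / scale, finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

weight = torch.randn(6, 4).to(torch.float8_e4m3fn)  # pretend fused FP8 weight
weight_scale = torch.tensor([0.01, 0.02, 0.04])     # per-partition scales
logical_widths = [2, 2, 2]                          # rows per logical partition

max_w_scale = weight_scale.max()
start = 0
for idx, logical_width in enumerate(logical_widths):
    end = start + logical_width
    weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
    weight[start:end, :], _ = naive_scaled_fp8_quant(weight_dq, max_w_scale)
    start = end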