[Refactor] Reducing code duplication across FP8 CUDA quantization kernels (#4163)

This commit is contained in:
Stefan He
2025-03-06 22:58:52 -08:00
committed by GitHub
parent c7f254468f
commit 95085d65e9
5 changed files with 32 additions and 64 deletions

View File

@@ -1,13 +1,12 @@
import itertools
import math
from typing import Any, Dict, List, Optional, Tuple
from typing import Tuple
import torch
import triton
import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn