unify is_cuda and is_hip (#4321)

This commit is contained in:
Yineng Zhang
2025-03-11 18:12:56 -07:00
committed by GitHub
parent 1cf63485c1
commit d1da58e275
18 changed files with 104 additions and 92 deletions

View File

@@ -54,9 +54,9 @@ from sglang.srt.utils import (
ACTIVATION_SCHEMES = ["static", "dynamic"]
is_hip_ = is_hip()
_is_hip = is_hip()
if is_hip_:
if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe
from aiter.ops.shuffle import shuffle_weight
@@ -175,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
# kernel for fast weight-only FP8 quantization
self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
# Disable marlin for ROCm
if is_hip_:
if _is_hip:
self.use_marlin = False
self.block_quant = self.quant_config.weight_block_size is not None
@@ -287,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if _is_hip:
# activation_scheme: dynamic
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight
weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if _is_hip:
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
@@ -563,7 +563,7 @@ class Fp8MoEMethod:
layer.register_parameter("w2_weight_scale", w2_weight_scale)
if (
is_hip_
_is_hip
): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter(
@@ -630,7 +630,7 @@ class Fp8MoEMethod:
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if _is_hip:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
@@ -667,7 +667,7 @@ class Fp8MoEMethod:
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -689,7 +689,7 @@ class Fp8MoEMethod:
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
if is_hip_:
if _is_hip:
self.process_weights_hip_scale_padding(layer)
return
@@ -721,7 +721,7 @@ class Fp8MoEMethod:
)
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if _is_hip:
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
@@ -771,7 +771,7 @@ class Fp8MoEMethod:
max_w13_scales, requires_grad=False
)
if is_hip_:
if _is_hip:
self.process_weights_hip_scale_padding(layer)
return
@@ -882,7 +882,7 @@ class Fp8MoEMethod:
correction_bias=correction_bias,
)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return asm_moe(
@@ -895,7 +895,7 @@ class Fp8MoEMethod:
layer.w2_weight_scale1,
activation=activation,
)
if is_hip_ and get_bool_env_var("CK_MOE"):
if _is_hip and get_bool_env_var("CK_MOE"):
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"

View File

@@ -22,12 +22,12 @@ import torch
import triton
import triton.language as tl
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_cuda = is_cuda()
if _is_cuda:
import deep_gemm
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
@@ -157,7 +157,7 @@ def per_token_group_quant_fp8(
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if is_hip_:
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
@@ -332,7 +332,7 @@ def static_quant_fp8(
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if is_hip_:
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
@@ -732,7 +732,7 @@ def w8a8_block_fp8_matmul(
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count())
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)

View File

@@ -17,8 +17,8 @@ from sglang.srt.utils import (
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
is_hip_ = is_hip()
if is_hip_ and get_bool_env_var("CK_MOE"):
_is_hip = is_hip()
if _is_hip and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale
_is_cuda = is_cuda()
@@ -111,7 +111,7 @@ def apply_w8a8_block_fp8_linear(
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
)
elif is_hip_ and get_bool_env_var("CK_MOE"):
elif _is_hip and get_bool_env_var("CK_MOE"):
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
@@ -142,7 +142,7 @@ def input_to_float8(
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
fp8_max = finfo.max
if is_hip_:
if _is_hip:
fp8_max = 224.0
scale = fp8_max / amax
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)

View File

@@ -16,6 +16,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import is_hip
_is_hip = is_hip()
class W8A8Fp8Config(QuantizationConfig):
"""Config class for W8A8 FP8 Quantization.
@@ -71,7 +73,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight
weight_scale = layer.weight_scale.detach()
if is_hip():
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)