unify is_cuda and is_hip (#4321)
@@ -54,9 +54,9 @@ from sglang.srt.utils import (
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
-is_hip_ = is_hip()
+_is_hip = is_hip()
 
-if is_hip_:
+if _is_hip:
     from aiter.fused_moe_bf16_asm import asm_moe
     from aiter.ops.shuffle import shuffle_weight
 
@@ -175,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip_:
+        if _is_hip:
             self.use_marlin = False
 
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -287,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -563,7 +563,7 @@ class Fp8MoEMethod:
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
         if (
-            is_hip_
+            _is_hip
         ): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
@@ -630,7 +630,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -667,7 +667,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -689,7 +689,7 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-            if is_hip_:
+            if _is_hip:
                 self.process_weights_hip_scale_padding(layer)
             return
 
@@ -721,7 +721,7 @@ class Fp8MoEMethod:
                 )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -771,7 +771,7 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
-            if is_hip_:
+            if _is_hip:
                 self.process_weights_hip_scale_padding(layer)
             return
 
@@ -882,7 +882,7 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-        if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
             # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
             assert not no_combine, f"{no_combine=} is not supported."
             return asm_moe(
@@ -895,7 +895,7 @@ class Fp8MoEMethod:
                 layer.w2_weight_scale1,
                 activation=activation,
             )
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
             # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
             assert (
                 activation == "silu"
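
All of the hunks above touch the same pattern: a platform check evaluated once at module import time, cached in a flag, and used to gate ROCm-only imports, FP8 dtype choices, and weight post-processing. Below is a minimal, self-contained sketch of that pattern; the is_hip() stand-in and the pick_fp8_dtype() helper are illustrative assumptions, not code from this commit, which only renames the existing flag from is_hip_ to _is_hip.

import torch


def is_hip() -> bool:
    # Stand-in for sglang.srt.utils.is_hip(); a ROCm build of PyTorch
    # exposes a non-None torch.version.hip.
    return torch.version.hip is not None


# Evaluate the platform check once at import time; callers branch on the
# cached flag instead of re-probing the runtime on every call.
_is_hip = is_hip()

if _is_hip:
    # ROCm-only kernels (e.g. aiter's asm_moe) would be imported here.
    pass


def pick_fp8_dtype() -> torch.dtype:
    # Hypothetical helper mirroring the dtype selection seen in the diff:
    # MI300-class hardware uses the e4m3fnuz FP8 variant, other GPUs e4m3fn.
    return torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn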